Skip to content

Commit

Permalink
async: add streaming support for server
Browse files Browse the repository at this point in the history
Added streaming support for server-side.

Signed-off-by: wllenyj <[email protected]>
  • Loading branch information
wllenyj committed Jun 10, 2022
1 parent ee5cce8 commit 8b1b1d7
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 32 deletions.
5 changes: 3 additions & 2 deletions src/asynchronous/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@ mod connection;
pub mod shutdown;
mod unix_incoming;

pub use self::stream::{Kind, StreamInner};
#[doc(inline)]
pub use crate::r#async::client::Client;
#[doc(inline)]
pub use crate::r#async::server::Server;
pub use crate::r#async::server::{Server, Service};
#[doc(inline)]
pub use utils::{MethodHandler, TtrpcContext};
pub use utils::{MethodHandler, StreamHandler, TtrpcContext};
198 changes: 168 additions & 30 deletions src/asynchronous/server.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
// Copyright 2022 Alibaba Cloud. All rights reserved.
// Copyright (c) 2020 Ant Financial
//
// SPDX-License-Identifier: Apache-2.0
Expand All @@ -10,7 +11,7 @@ use std::os::unix::io::RawFd;
use std::os::unix::io::{AsRawFd, FromRawFd};
use std::os::unix::net::UnixListener as SysUnixListener;
use std::result::Result as StdResult;
use std::sync::Arc;
use std::sync::{Arc, Mutex};
use std::time::Duration;

use async_trait::async_trait;
Expand All @@ -34,22 +35,39 @@ use crate::common::{self, Domain};
use crate::context;
use crate::error::{get_status, Error, Result};
use crate::proto::{
Code, Codec, GenMessage, Message, MessageHeader, Request, Response, Status,
MESSAGE_TYPE_REQUEST,
Code, Codec, GenMessage, Message, MessageHeader, Request, Response, Status, FLAG_NO_DATA,
FLAG_REMOTE_CLOSED, FLAG_REMOTE_OPEN, MESSAGE_TYPE_DATA, MESSAGE_TYPE_REQUEST,
};
use crate::r#async::connection::*;
use crate::r#async::shutdown;
use crate::r#async::stream::{MessageReceiver, MessageSender};
use crate::r#async::stream::{
Kind, MessageReceiver, MessageSender, ResultReceiver, ResultSender, StreamInner,
};
use crate::r#async::utils;
use crate::r#async::{MethodHandler, TtrpcContext};
use crate::r#async::{MethodHandler, StreamHandler, TtrpcContext};

const DEFAULT_CONN_SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(5000);
const DEFAULT_SERVER_SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(10000);

pub struct Service {
pub methods: HashMap<String, Box<dyn MethodHandler + Send + Sync>>,
pub streams: HashMap<String, Arc<dyn StreamHandler + Send + Sync>>,
}

impl Service {
pub(crate) fn get_method(&self, name: &String) -> Option<&(dyn MethodHandler + Send + Sync)> {
self.methods.get(name).map(|b| b.as_ref())
}

pub(crate) fn get_stream(&self, name: &String) -> Option<Arc<dyn StreamHandler + Send + Sync>> {
self.streams.get(name).cloned()
}
}

/// A ttrpc Server (async).
pub struct Server {
listeners: Vec<RawFd>,
methods: Arc<HashMap<String, Box<dyn MethodHandler + Send + Sync>>>,
services: Arc<HashMap<String, Service>>,
domain: Option<Domain>,

shutdown: shutdown::Notifier,
Expand All @@ -60,7 +78,7 @@ impl Default for Server {
fn default() -> Self {
Server {
listeners: Vec::with_capacity(1),
methods: Arc::new(HashMap::new()),
services: Arc::new(HashMap::new()),
domain: None,
shutdown: shutdown::with_timeout(DEFAULT_SERVER_SHUTDOWN_TIMEOUT).0,
stop_listen_tx: None,
Expand Down Expand Up @@ -105,12 +123,9 @@ impl Server {
Ok(self)
}

pub fn register_service(
mut self,
methods: HashMap<String, Box<dyn MethodHandler + Send + Sync>>,
) -> Server {
let mut_methods = Arc::get_mut(&mut self.methods).unwrap();
mut_methods.extend(methods);
pub fn register_service(mut self, new: HashMap<String, Service>) -> Server {
let services = Arc::get_mut(&mut self.services).unwrap();
services.extend(new);
self
}

Expand Down Expand Up @@ -158,7 +173,7 @@ impl Server {
I: Stream<Item = std::io::Result<S>> + Unpin + Send + 'static + AsRawFd,
S: AsyncRead + AsyncWrite + AsRawFd + Send + 'static,
{
let methods = self.methods.clone();
let services = self.services.clone();

let shutdown_waiter = self.shutdown.subscribe();

Expand All @@ -172,13 +187,13 @@ impl Server {
if let Some(conn) = conn {
// Accept a new connection
match conn {
Ok(stream) => {
let fd = stream.as_raw_fd();
Ok(conn) => {
let fd = conn.as_raw_fd();
// spawn a connection handler, would not block
spawn_connection_handler(
fd,
stream,
methods.clone(),
conn,
services.clone(),
shutdown_waiter.clone(),
).await;
}
Expand Down Expand Up @@ -244,14 +259,15 @@ impl Server {
async fn spawn_connection_handler<C>(
fd: RawFd,
conn: C,
methods: Arc<HashMap<String, Box<dyn MethodHandler + Send + Sync>>>,
services: Arc<HashMap<String, Service>>,
shutdown_waiter: shutdown::Waiter,
) where
C: AsyncRead + AsyncWrite + AsRawFd + Send + 'static,
{
let delegate = ServerBuilder {
fd,
methods,
services,
streams: Arc::new(Mutex::new(HashMap::new())),
shutdown_waiter,
};
let conn = Connection::new(conn, delegate);
Expand Down Expand Up @@ -279,7 +295,8 @@ impl AsRawFd for Server {

struct ServerBuilder {
fd: RawFd,
methods: Arc<HashMap<String, Box<dyn MethodHandler + Send + Sync>>>,
services: Arc<HashMap<String, Service>>,
streams: Arc<Mutex<HashMap<u32, ResultSender>>>,
shutdown_waiter: shutdown::Waiter,
}

Expand All @@ -296,7 +313,8 @@ impl Builder for ServerBuilder {
ServerReader {
fd: self.fd,
tx,
methods: self.methods.clone(),
services: self.services.clone(),
streams: self.streams.clone(),
server_shutdown: self.shutdown_waiter.clone(),
handler_shutdown: disconnect_notifier,
},
Expand All @@ -321,7 +339,8 @@ impl WriterDelegate for ServerWriter {
struct ServerReader {
fd: RawFd,
tx: MessageSender,
methods: Arc<HashMap<String, Box<dyn MethodHandler + Send + Sync>>>,
services: Arc<HashMap<String, Service>>,
streams: Arc<Mutex<HashMap<u32, ResultSender>>>,
server_shutdown: shutdown::Waiter,
handler_shutdown: shutdown::Notifier,
}
Expand Down Expand Up @@ -366,7 +385,8 @@ impl ServerReader {
HandlerContext {
fd: self.fd,
tx: self.tx.clone(),
methods: self.methods.clone(),
services: self.services.clone(),
streams: self.streams.clone(),
_handler_shutdown_waiter: self.handler_shutdown.subscribe(),
}
}
Expand All @@ -375,7 +395,8 @@ impl ServerReader {
struct HandlerContext {
fd: RawFd,
tx: MessageSender,
methods: Arc<HashMap<String, Box<dyn MethodHandler + Send + Sync>>>,
services: Arc<HashMap<String, Service>>,
streams: Arc<Mutex<HashMap<u32, ResultSender>>>,
// Used for waiting handler exit.
_handler_shutdown_waiter: shutdown::Waiter,
}
Expand Down Expand Up @@ -406,11 +427,62 @@ impl HandlerContext {
.ok();
}
None => {
unimplemented!();
let mut header = MessageHeader::new_data(stream_id, 0);
header.set_flags(FLAG_REMOTE_CLOSED | FLAG_NO_DATA);
let msg = GenMessage {
header,
payload: Vec::new(),
};

self.tx
.send(msg)
.await
.map_err(err_to_others_err!(e, "Send packet to sender error "))
.ok();
}
},
Err(status) => Self::respond_with_status(self.tx.clone(), stream_id, status).await,
},
MESSAGE_TYPE_DATA => {
// TODO(wllenyj): Compatible with golang behavior.
if (msg.header.flags & FLAG_REMOTE_CLOSED) == FLAG_REMOTE_CLOSED
&& !msg.payload.is_empty()
{
Self::respond_with_status(
self.tx.clone(),
stream_id,
get_status(
Code::INVALID_ARGUMENT,
format!(
"Stream id {}: data close message connot include data",
stream_id
),
),
)
.await;
}
let stream_tx = self.streams.lock().unwrap().get(&stream_id).cloned();
if let Some(stream_tx) = stream_tx {
if let Err(e) = stream_tx.send(Ok(msg)).await {
Self::respond_with_status(
self.tx.clone(),
stream_id,
get_status(
Code::INVALID_ARGUMENT,
format!("Stream id {}: handling data error: {}", stream_id, e),
),
)
.await;
}
} else {
Self::respond_with_status(
self.tx.clone(),
stream_id,
get_status(Code::INVALID_ARGUMENT, "Stream is no longger active"),
)
.await;
}
}
_ => {
// TODO: else we must ignore this for future compat. log this?
// TODO(wllenyj): Compatible with golang behavior.
Expand All @@ -432,12 +504,23 @@ impl HandlerContext {
let req = &req_msg.payload;
trace!("Got Message request {} {}", req.service, req.method);

let path = utils::get_path(&req.service, &req.method);
let method = self.methods.get(&path).ok_or_else(|| {
get_status(Code::INVALID_ARGUMENT, format!("{} does not exist", &path))
let srv = self.services.get(&req.service).ok_or_else(|| {
get_status(
Code::INVALID_ARGUMENT,
format!("{} service does not exist", &req.service),
)
})?;

return self.handle_method(method.as_ref(), req_msg).await;
if let Some(method) = srv.get_method(&req.method) {
return self.handle_method(method, req_msg).await;
}
if let Some(stream) = srv.get_stream(&req.method) {
return self.handle_stream(stream, req_msg).await;
}
Err(get_status(
Code::UNIMPLEMENTED,
format!("{} method", &req.method),
))
}

async fn handle_method(
Expand Down Expand Up @@ -484,6 +567,61 @@ impl HandlerContext {
}
}

async fn handle_stream(
&self,
stream: Arc<dyn StreamHandler + Send + Sync>,
req_msg: Message<Request>,
) -> StdResult<Option<Response>, Status> {
let stream_id = req_msg.header.stream_id;
let req = req_msg.payload;
let path = utils::get_path(&req.service, &req.method);

let (tx, rx): (ResultSender, ResultReceiver) = channel(100);
let stream_tx = tx.clone();
self.streams.lock().unwrap().insert(stream_id, tx);

let _remote_close = (req_msg.header.flags & FLAG_REMOTE_CLOSED) == FLAG_REMOTE_CLOSED;
let _remote_open = (req_msg.header.flags & FLAG_REMOTE_OPEN) == FLAG_REMOTE_OPEN;
let si = StreamInner::new(
stream_id,
self.tx.clone(),
rx,
true, // TODO
true,
Kind::Server,
self.streams.clone(),
);

let ctx = TtrpcContext {
fd: self.fd,
mh: req_msg.header,
metadata: context::from_pb(&req.metadata),
timeout_nano: req.timeout_nano,
};

let task = spawn(async move { stream.handler(ctx, si).await });

if !req.payload.is_empty() {
// Fake the first data message.
let msg = GenMessage {
header: MessageHeader::new_data(stream_id, req.payload.len() as u32),
payload: req.payload,
};
stream_tx.send(Ok(msg)).await.map_err(|e| {
error!("send stream data {} got error {:?}", path, &e);
get_status(Code::UNKNOWN, e)
})?;
}
task.await
.unwrap_or_else(|e| {
Err(Error::Others(format!(
"stream {} task got error {:?}",
path, e
)))
})
.map_err(|e| get_status(Code::UNKNOWN, e))
}

async fn respond(tx: MessageSender, stream_id: u32, resp: Response) -> Result<()> {
let payload = resp
.encode()
Expand Down
10 changes: 10 additions & 0 deletions src/asynchronous/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,16 @@ pub trait MethodHandler {
async fn handler(&self, ctx: TtrpcContext, req: Request) -> Result<Response>;
}

/// Trait that implements handler which is a proxy to the stream (async).
#[async_trait]
pub trait StreamHandler {
async fn handler(
&self,
ctx: TtrpcContext,
stream: crate::r#async::StreamInner,
) -> Result<Option<Response>>;
}

/// The context of ttrpc (async).
#[derive(Debug)]
pub struct TtrpcContext {
Expand Down

0 comments on commit 8b1b1d7

Please sign in to comment.