Skip to content

Commit

Permalink
Merge pull request #199 from Tim-Zhang/fix-over-size-limit-0.5.0
Browse files Browse the repository at this point in the history
[0.5.x] Fix the bug caused by oversized packets
  • Loading branch information
teawater authored Jul 19, 2023
2 parents 8968bfa + b195210 commit b0ca2c3
Show file tree
Hide file tree
Showing 10 changed files with 289 additions and 166 deletions.
88 changes: 48 additions & 40 deletions src/asynchronous/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
//

use nix::unistd::close;
use protobuf::{CodedInputStream, CodedOutputStream, Message};
use protobuf::{CodedInputStream, Message};
use std::collections::HashMap;
use std::os::unix::io::RawFd;
use std::sync::{Arc, Mutex};

use crate::common::{client_connect, MESSAGE_TYPE_RESPONSE};
use crate::common::{
check_oversize, client_connect, convert_msg_to_buf, MessageHeader, MESSAGE_TYPE_RESPONSE,
};
use crate::error::{Error, Result};
use crate::ttrpc::{Code, Request, Response};

Expand Down Expand Up @@ -101,38 +103,7 @@ impl Client {
res = receive(&mut reader) => {
match res {
Ok((header, body)) => {
tokio::spawn(async move {
let resp_tx2;
{
let mut map = req_map.lock().unwrap();
let resp_tx = match map.get(&header.stream_id) {
Some(tx) => tx,
None => {
debug!(
"Receiver got unknown packet {:?} {:?}",
header, body
);
return;
}
};

resp_tx2 = resp_tx.clone();
map.remove(&header.stream_id); // Forget the result, just remove.
}

if header.type_ != MESSAGE_TYPE_RESPONSE {
resp_tx2
.send(Err(Error::Others(format!(
"Recver got malformed packet {:?} {:?}",
header, body
))))
.await
.unwrap_or_else(|_e| error!("The request has returned"));
return;
}

resp_tx2.send(Ok(body)).await.unwrap_or_else(|_e| error!("The request has returned"));
});
spawn_trans_resp(req_map, header, body);
}
Err(e) => {
trace!("error {:?}", e);
Expand All @@ -148,12 +119,9 @@ impl Client {
}

pub async fn request(&self, req: Request) -> Result<Response> {
let mut buf = Vec::with_capacity(req.compute_size() as usize);
{
let mut s = CodedOutputStream::vec(&mut buf);
req.write_to(&mut s).map_err(err_to_others_err!(e, ""))?;
s.flush().map_err(err_to_others_err!(e, ""))?;
}
let buf = convert_msg_to_buf(&req)?;
// NOTE: pure client problem can't be rpc error, so we use false here.
check_oversize(buf.len(), false)?;

let (tx, mut rx): (ResponseSender, ResponseReceiver) = channel(100);
self.req_tx
Expand Down Expand Up @@ -202,3 +170,43 @@ impl Drop for ClientClose {
trace!("All client is droped");
}
}

// Spwan a task and transfer the response
fn spawn_trans_resp(
req_map: Arc<Mutex<HashMap<u32, ResponseSender>>>,
header: MessageHeader,
body: Result<Vec<u8>>,
) {
tokio::spawn(async move {
let resp_tx2;
{
let mut map = req_map.lock().unwrap();
let resp_tx = match map.get(&header.stream_id) {
Some(tx) => tx,
None => {
debug!("Receiver got unknown packet {:?} {:?}", header, body);
return;
}
};

resp_tx2 = resp_tx.clone();
map.remove(&header.stream_id); // Forget the result, just remove.
}

if header.type_ != MESSAGE_TYPE_RESPONSE {
resp_tx2
.send(Err(Error::Others(format!(
"Recver got malformed packet {:?}",
header
))))
.await
.unwrap_or_else(|_e| error!("The request has returned"));
return;
}

resp_tx2
.send(body)
.await
.unwrap_or_else(|_e| error!("The request has returned"));
});
}
11 changes: 8 additions & 3 deletions src/asynchronous/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use std::result::Result as StdResult;
use std::sync::Arc;
use std::time::Duration;

use crate::asynchronous::stream::{receive, respond, respond_with_status};
use crate::asynchronous::stream::{receive, respond, respond_error, respond_with_status};
use crate::asynchronous::unix_incoming::UnixIncoming;
use crate::common::{self, Domain, MESSAGE_TYPE_REQUEST};
use crate::context;
Expand Down Expand Up @@ -269,16 +269,21 @@ async fn spawn_connection_handler<S>(
select! {
resp = receive(&mut reader) => {
match resp {
Ok(message) => {
Ok((mh, Ok(body))) => {
spawn(async move {
select! {
_ = handle_request(tx, fd, methods, message) => {}
_ = handle_request(tx, fd, methods, (mh, body)) => {}
_ = client_disconnected_rx2.changed() => {}
}

drop(req_done_tx2);
});
}
Ok((mh, Err(e))) => {
respond_error(tx.clone(), mh.stream_id, e).await.map_err(|e| {
error!("respond-error got error {:?}", e);
}).ok();
},
Err(e) => {
let _ = client_disconnected_tx.send(true);
trace!("error {:?}", e);
Expand Down
68 changes: 46 additions & 22 deletions src/asynchronous/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,18 @@
// SPDX-License-Identifier: Apache-2.0
//

use std::cmp;

use byteorder::{BigEndian, ByteOrder};

use crate::common::{MESSAGE_HEADER_LENGTH, MESSAGE_LENGTH_MAX, MESSAGE_TYPE_RESPONSE};
use crate::common::{
check_oversize, convert_msg_to_buf, DEFAULT_PAGE_SIZE, MESSAGE_HEADER_LENGTH,
MESSAGE_TYPE_RESPONSE,
};
use crate::error::{get_rpc_status, sock_error_msg, Error, Result};
use crate::r#async::utils;
use crate::ttrpc::{Code, Response, Status};
use crate::MessageHeader;
use protobuf::Message;
use tokio::io::AsyncReadExt;

async fn receive_count<T>(reader: &mut T, count: usize) -> Result<Vec<u8>>
Expand All @@ -25,6 +29,21 @@ where
Ok(content)
}

async fn discard_count<T>(reader: &mut T, count: usize) -> Result<()>
where
T: AsyncReadExt + std::marker::Unpin,
{
let mut need_discard = count;

while need_discard > 0 {
let once_discard = cmp::min(DEFAULT_PAGE_SIZE, need_discard);
receive_count(reader, once_discard).await?;
need_discard -= once_discard;
}

Ok(())
}

async fn receive_header<T>(reader: &mut T) -> Result<MessageHeader>
where
T: AsyncReadExt + std::marker::Unpin,
Expand All @@ -51,21 +70,17 @@ where
Ok(mh)
}

pub async fn receive<T>(reader: &mut T) -> Result<(MessageHeader, Vec<u8>)>
pub async fn receive<T>(reader: &mut T) -> Result<(MessageHeader, Result<Vec<u8>>)>
where
T: AsyncReadExt + std::marker::Unpin,
{
let mh = receive_header(reader).await?;
trace!("Got Message header {:?}", mh);

if mh.length > MESSAGE_LENGTH_MAX as u32 {
return Err(get_rpc_status(
Code::INVALID_ARGUMENT,
format!(
"message length {} exceed maximum message size of {}",
mh.length, MESSAGE_LENGTH_MAX
),
));
let mh_len = mh.length as usize;
if let Err(e) = check_oversize(mh_len, true) {
discard_count(reader, mh_len).await?;
return Ok((mh, Err(e)));
}

let buf = receive_count(reader, mh.length as usize).await?;
Expand All @@ -78,7 +93,7 @@ where
}
trace!("Got Message body {:?}", buf);

Ok((mh, buf))
Ok((mh, Ok(buf)))
}

fn header_to_buf(mh: MessageHeader) -> Vec<u8> {
Expand Down Expand Up @@ -110,15 +125,6 @@ pub fn to_res_buf(stream_id: u32, mut body: Vec<u8>) -> Vec<u8> {
buf
}

fn get_response_body(res: &Response) -> Result<Vec<u8>> {
let mut buf = Vec::with_capacity(res.compute_size() as usize);
let mut s = protobuf::CodedOutputStream::vec(&mut buf);
res.write_to(&mut s).map_err(err_to_others_err!(e, ""))?;
s.flush().map_err(err_to_others_err!(e, ""))?;

Ok(buf)
}

pub async fn respond(
tx: tokio::sync::mpsc::Sender<Vec<u8>>,
stream_id: u32,
Expand All @@ -138,7 +144,7 @@ pub async fn respond_with_status(
) -> Result<()> {
let mut res = Response::new();
res.set_status(status);
let mut body = get_response_body(&res)?;
let mut body = convert_msg_to_buf(&res)?;

let mh = MessageHeader {
length: body.len() as u32,
Expand All @@ -154,3 +160,21 @@ pub async fn respond_with_status(
.await
.map_err(err_to_others_err!(e, "Send packet to sender error "))
}

pub(crate) async fn respond_error(
tx: tokio::sync::mpsc::Sender<Vec<u8>>,
stream_id: u32,
e: Error,
) -> Result<()> {
let status = if let Error::RpcStatus(stat) = e {
stat
} else {
Status {
code: Code::UNKNOWN,
message: format!("{:?}", e),
..Default::default()
}
};

respond_with_status(tx, stream_id, status).await
}
16 changes: 10 additions & 6 deletions src/asynchronous/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
// SPDX-License-Identifier: Apache-2.0
//

use crate::common::{MessageHeader, MESSAGE_TYPE_REQUEST, MESSAGE_TYPE_RESPONSE};
use crate::error::{get_status, Error, Result};
use crate::common::{
check_oversize, convert_error_to_response, convert_msg_to_buf, MessageHeader,
MESSAGE_TYPE_REQUEST, MESSAGE_TYPE_RESPONSE,
};
use crate::error::{get_status, Result};
use crate::ttrpc::{Code, Request, Response, Status};
use async_trait::async_trait;
use protobuf::{CodedInputStream, Message};
Expand Down Expand Up @@ -96,10 +99,11 @@ pub struct TtrpcContext {
}

pub fn convert_response_to_buf(res: Response) -> Result<Vec<u8>> {
let mut buf = Vec::with_capacity(res.compute_size() as usize);
let mut s = protobuf::CodedOutputStream::vec(&mut buf);
res.write_to(&mut s).map_err(err_to_others_err!(e, ""))?;
s.flush().map_err(err_to_others_err!(e, ""))?;
let mut buf = convert_msg_to_buf(&res)?;
if let Err(e) = check_oversize(buf.len(), true) {
let resp = convert_error_to_response(e);
buf = convert_msg_to_buf(&resp)?;
};

Ok(buf)
}
Expand Down
49 changes: 47 additions & 2 deletions src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,14 @@
#![allow(unused_macros)]

use crate::error::{Error, Result};
use std::os::unix::io::RawFd;

use nix::fcntl::{fcntl, FcntlArg, FdFlag, OFlag};
use nix::sys::socket::*;
use std::os::unix::io::RawFd;
use protobuf::Message;

use crate::error::{get_rpc_status, get_status, Error, Result};
use crate::ttrpc::{Code, Response};

#[derive(Debug)]
pub enum Domain {
Expand Down Expand Up @@ -155,6 +159,45 @@ pub(crate) unsafe fn client_connect(host: &str) -> Result<RawFd> {
Ok(fd)
}

pub fn check_oversize(len: usize, return_rpc_error: bool) -> Result<()> {
if len > MESSAGE_LENGTH_MAX {
let msg = format!(
"message length {} exceed maximum message size of {}",
len, MESSAGE_LENGTH_MAX
);
let e = if return_rpc_error {
get_rpc_status(Code::INVALID_ARGUMENT, msg)
} else {
Error::Others(msg)
};

return Err(e);
}

Ok(())
}

pub fn convert_msg_to_buf(msg: &impl Message) -> Result<Vec<u8>> {
let mut buf = Vec::with_capacity(msg.compute_size() as usize);
let mut s = protobuf::CodedOutputStream::vec(&mut buf);
msg.write_to(&mut s).map_err(err_to_others_err!(e, ""))?;
s.flush().map_err(err_to_others_err!(e, ""))?;

Ok(buf)
}

pub fn convert_error_to_response(e: Error) -> Response {
let status = if let Error::RpcStatus(stat) = e {
stat
} else {
get_status(Code::UNKNOWN, e)
};

let mut res = Response::new();
res.set_status(status);
res
}

macro_rules! cfg_sync {
($($item:item)*) => {
$(
Expand All @@ -180,3 +223,5 @@ pub const MESSAGE_LENGTH_MAX: usize = 4 << 20;

pub const MESSAGE_TYPE_REQUEST: u8 = 0x1;
pub const MESSAGE_TYPE_RESPONSE: u8 = 0x2;

pub const DEFAULT_PAGE_SIZE: usize = 4 << 10;
Loading

0 comments on commit b0ca2c3

Please sign in to comment.