Skip to content

Commit

Permalink
Avoid cancelling a read call (#144)
Browse files Browse the repository at this point in the history
Stream read is not cancellation safe, it can then put the
socket in a bad state and reads will get blocked. forever

Instead we run the read in own routine, and then pipe read messages
to the connection loop. which is cancellation safe.

Fixes #143
  • Loading branch information
muhamadazmy authored Jul 27, 2023
1 parent 1ab84ab commit 05f967f
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 6 deletions.
39 changes: 36 additions & 3 deletions src/peer/con.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use futures_util::stream::StreamExt;
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::time::Instant;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::tungstenite::{Error, Message};
use url::Url;

const PING_INTERVAL: Duration = Duration::from_secs(20);
Expand Down Expand Up @@ -104,9 +104,10 @@ async fn retainer<S: Signer>(
}
};

let (mut write, mut read) = ws.split();
let (mut write, read) = ws.split();
let mut last = Instant::now();

let mut read = read_stream(read);
'receive: loop {
// we check here when was the last time a message was received
// from the relay. we expect to receive PONG answers (because we
Expand All @@ -129,7 +130,15 @@ async fn retainer<S: Signer>(
break 'receive;
}
},
Some(message) = read.next() => {
message = read.recv() => {
let message = match message {
None=> {
log::debug!("read stream ended") ;
break 'receive;
},
Some(message) => message,
};

// we take a note with when a message was received
last = Instant::now();
log::trace!("received a message from relay");
Expand Down Expand Up @@ -158,3 +167,27 @@ async fn retainer<S: Signer>(
log::info!("retrying connection");
}
}

fn read_stream(
mut stream: futures_util::stream::SplitStream<
tokio_tungstenite::WebSocketStream<
tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
>,
>,
) -> mpsc::Receiver<Result<Message, Error>> {
let (sender, receiver) = mpsc::channel(1);
tokio::spawn(async move {
loop {
match stream.next().await {
None => return,
Some(result) => {
if sender.send(result).await.is_err() {
return;
}
}
}
}
});

receiver
}
2 changes: 1 addition & 1 deletion src/peer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ where
// otherwise, we clear up the payload
// and set the error instead
envelope.payload = None;
let mut e = envelope.mut_error();
let e = envelope.mut_error();
e.code = err.code();
e.message = err.to_string();
}
Expand Down
2 changes: 1 addition & 1 deletion src/relay/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ impl<M: Metrics> Stream<M> {
let mut resp = Envelope::new();
resp.uid = envelope.uid;
resp.destination = Some((&self.id).into()).into();
let mut e = resp.mut_error();
let e = resp.mut_error();
e.message = err.to_string();

let bytes = match resp.write_to_bytes() {
Expand Down
2 changes: 1 addition & 1 deletion src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::time::{Duration, SystemTime};

include!(concat!(env!("OUT_DIR"), "/protos/mod.rs"));

pub use peer::*;
pub use peer::Backlog;
pub use types::*;

#[derive(thiserror::Error, Debug)]
Expand Down

0 comments on commit 05f967f

Please sign in to comment.