Skip to content

Commit

Permalink
212/fix-infinite-hangs
Browse files Browse the repository at this point in the history
  • Loading branch information
alk888 committed Feb 14, 2024
1 parent 7db8d45 commit 7d0139e
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 18 deletions.
61 changes: 50 additions & 11 deletions src/client/dispatcher.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use futures::Stream;
use rabbitmq_stream_protocol::Response;
use std::sync::{atomic::AtomicU32, Arc};
use std::sync::{atomic::{AtomicBool, AtomicU32, Ordering}, Arc};
use tracing::trace;

use dashmap::DashMap;
Expand All @@ -17,7 +17,7 @@ use super::{channel::ChannelReceiver, handler::MessageHandler};
pub(crate) struct Dispatcher<T>(DispatcherState<T>);

pub(crate) struct DispatcherState<T> {
requests: Arc<DashMap<u32, Sender<Response>>>,
requests: Arc<RequestsMap>,
correlation_id: Arc<AtomicU32>,
handler: Arc<RwLock<Option<T>>>,
}
Expand All @@ -32,13 +32,49 @@ impl<T> Clone for DispatcherState<T> {
}
}

struct RequestsMap {
requests: DashMap<u32, Sender<Response>>,
closed: AtomicBool
}

impl RequestsMap {
fn new() -> RequestsMap {
RequestsMap {
requests: DashMap::new(),
closed: AtomicBool::new(false)
}
}

fn insert(&self, correlation_id: u32, sender: Sender<Response>) -> bool {
if self.closed.load(Ordering::Relaxed) {
return false;
}
self.requests.insert(correlation_id, sender);
true
}

fn remove(&self, correlation_id: u32) -> Option<Sender<Response>> {
self.requests.remove(&correlation_id).map(|r| r.1)
}

fn close(&self) {
self.closed.store(true, Ordering::Relaxed);
self.requests.clear();
}

#[cfg(test)]
fn len(&self) -> usize {
self.requests.len()
}
}

impl<T> Dispatcher<T>
where
T: MessageHandler,
{
pub fn new() -> Dispatcher<T> {
Dispatcher(DispatcherState {
requests: Arc::new(DashMap::new()),
requests: Arc::new(RequestsMap::new()),
correlation_id: Arc::new(AtomicU32::new(0)),
handler: Arc::new(RwLock::new(None)),
})
Expand All @@ -47,23 +83,25 @@ where
#[cfg(test)]
pub fn with_handler(handler: T) -> Dispatcher<T> {
Dispatcher(DispatcherState {
requests: Arc::new(DashMap::new()),
requests: Arc::new(RequestsMap::new()),
correlation_id: Arc::new(AtomicU32::new(0)),
handler: Arc::new(RwLock::new(Some(handler))),
})
}

pub async fn response_channel(&self) -> (u32, Receiver<Response>) {
pub async fn response_channel(&self) -> Option<(u32, Receiver<Response>)> {
let (tx, rx) = channel(1);

let correlation_id = self
.0
.correlation_id
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);

self.0.requests.insert(correlation_id, tx);

(correlation_id, rx)
if self.0.requests.insert(correlation_id, tx) {
Some((correlation_id, rx))
} else {
None
}
}

#[cfg(test)]
Expand All @@ -89,10 +127,10 @@ where
T: MessageHandler,
{
pub async fn dispatch(&self, correlation_id: u32, response: Response) {
let receiver = self.requests.remove(&correlation_id);
let receiver = self.requests.remove(correlation_id);

if let Some(rcv) = receiver {
let _ = rcv.1.send(response).await;
let _ = rcv.send(response).await;
}
}

Expand All @@ -103,6 +141,7 @@ where
}

pub async fn close(self, error: Option<ClientError>) {
self.requests.close();
if let Some(handler) = self.handler.read().await.as_ref() {
if let Some(err) = error {
let _ = handler.handle_message(Some(Err(err))).await;
Expand Down Expand Up @@ -265,7 +304,7 @@ mod tests {

dispatcher.start(rx).await;

let (correlation_id, mut rx) = dispatcher.response_channel().await;
let (correlation_id, mut rx) = dispatcher.response_channel().await.unwrap();

let req: Request = PeerPropertiesCommand::new(correlation_id, HashMap::new()).into();

Expand Down
19 changes: 12 additions & 7 deletions src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use std::{fs::File, io::BufReader, path::Path};
use tokio::io::AsyncRead;
use tokio::io::AsyncWrite;
use tokio::io::ReadBuf;
use tokio::{net::TcpStream, sync::Notify};
use tokio::{net::TcpStream, sync::Semaphore};
use tokio::{sync::RwLock, task::JoinHandle};
use tokio_rustls::client::TlsStream;
use tokio_rustls::rustls::ClientConfig;
Expand Down Expand Up @@ -186,7 +186,7 @@ pub struct Client {
channel: Arc<ChannelSender<SinkConnection>>,
state: Arc<RwLock<ClientState>>,
opts: ClientOptions,
tune_notifier: Arc<Notify>,
tune_notifier: Arc<Semaphore>,
publish_sequence: Arc<AtomicU64>,
}

Expand All @@ -212,7 +212,7 @@ impl Client {
opts: broker,
channel: Arc::new(sender),
state: Arc::new(RwLock::new(state)),
tune_notifier: Arc::new(Notify::new()),
tune_notifier: Arc::new(Semaphore::new(0)),
publish_sequence: Arc::new(AtomicU64::new(1)),
};

Expand Down Expand Up @@ -501,7 +501,7 @@ impl Client {
}

async fn wait_for_tune_data(&mut self) -> Result<(), ClientError> {
self.tune_notifier.notified().await;
self.tune_notifier.acquire().await.ok();
Ok(())
}

Expand Down Expand Up @@ -545,13 +545,15 @@ impl Client {
T: FromResponse,
M: FnOnce(u32) -> R,
{
let (correlation_id, mut receiver) = self.dispatcher.response_channel().await;
let Some((correlation_id, mut receiver)) = self.dispatcher.response_channel().await else {
return Err(ClientError::ConnectionClosed);
};

self.channel
.send(msg_factory(correlation_id).into())
.await?;

let response = receiver.recv().await.expect("It should contain a response");
let response = receiver.recv().await.ok_or(ClientError::ConnectionClosed)?;

self.handle_response::<T>(response).await
}
Expand Down Expand Up @@ -616,7 +618,10 @@ impl Client {
if heart_beat != 0 {
let heartbeat_interval = (heart_beat / 2).max(1);
let channel = self.channel.clone();
let tune_notifier = self.tune_notifier.clone();
let heartbeat_task = tokio::spawn(async move {
// Wait for the tunes command to be processed
tune_notifier.acquire().await.ok();
loop {
trace!("Sending heartbeat");
let _ = channel.send(HeartBeatCommand::default().into()).await;
Expand All @@ -633,7 +638,7 @@ impl Client {
.send(TunesCommand::new(max_frame_size, heart_beat).into())
.await;

self.tune_notifier.notify_one();
self.tune_notifier.close();
}

async fn handle_heart_beat_command(&self) {
Expand Down
2 changes: 2 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ pub enum ClientError {
GenericError(#[from] Box<dyn std::error::Error + Send + Sync>),
#[error("Client already closed")]
AlreadyClosed,
#[error("Connection closed")]
ConnectionClosed,
#[error(transparent)]
Tls(#[from] tokio_rustls::rustls::Error),
#[error("Request error: {0:?}")]
Expand Down

0 comments on commit 7d0139e

Please sign in to comment.