diff --git a/src/client/mod.rs b/src/client/mod.rs index b2d13b2..da9ecb5 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -25,6 +25,7 @@ use rabbitmq_stream_protocol::{ delete::Delete, delete_publisher::DeletePublisherCommand, generic::GenericResponse, + heart_beat::HeartBeatCommand, metadata::MetadataCommand, open::{OpenCommand, OpenResponse}, peer_properties::{PeerPropertiesCommand, PeerPropertiesResponse}, @@ -41,6 +42,7 @@ use rabbitmq_stream_protocol::{ types::PublishedMessage, FromResponse, Request, Response, ResponseCode, ResponseKind, }; +use tokio_native_tls::TlsStream; use tracing::trace; pub use self::handler::{MessageHandler, MessageResult}; @@ -58,14 +60,14 @@ use std::{ pin::Pin, sync::{atomic::AtomicU64, Arc}, task::{Context, Poll}, + time::{Duration, Instant}, }; use std::{future::Future, sync::atomic::Ordering}; use tokio::io::AsyncRead; use tokio::io::AsyncWrite; use tokio::io::ReadBuf; -use tokio::sync::RwLock; use tokio::{net::TcpStream, sync::Notify}; -use tokio_native_tls::TlsStream; +use tokio::{sync::RwLock, task::JoinHandle}; use tokio_util::codec::Framed; #[cfg_attr(docsrs, doc(cfg(feature = "tokio-stream")))] @@ -125,6 +127,8 @@ pub struct ClientState { handler: Option>, heartbeat: u32, max_frame_size: u32, + last_heatbeat: Instant, + heartbeat_task: Option>, } #[async_trait::async_trait] @@ -133,6 +137,7 @@ impl MessageHandler for Client { match &item { Some(Ok(response)) => match response.kind_ref() { ResponseKind::Tunes(tune) => self.handle_tune_command(tune).await, + ResponseKind::Heartbeat(_) => self.handle_heart_beat_command().await, _ => { if let Some(handler) = self.state.read().await.handler.as_ref() { let handler = handler.clone(); @@ -188,6 +193,8 @@ impl Client { handler: None, heartbeat: broker.heartbeat, max_frame_size: broker.max_frame_size, + last_heatbeat: Instant::now(), + heartbeat_task: None, }; let mut client = Client { dispatcher, @@ -228,6 +235,14 @@ impl Client { CloseRequest::new(correlation_id, ResponseCode::Ok, "Ok".to_owned()) }) .await?; + + let mut state = self.state.write().await; + + if let Some(heartbeat_task) = state.heartbeat_task.take() { + heartbeat_task.abort(); + } + + drop(state); self.channel.close().await } pub async fn subscribe( @@ -451,10 +466,10 @@ impl Client { Ok(()) } - fn max_value(&self, client: u32, server: u32) -> u32 { + fn negotiate_value(&self, client: u32, server: u32) -> u32 { match (client, server) { (client, server) if client == 0 || server == 0 => client.max(server), - (client, server) => client.max(server), + (client, server) => client.min(server), } } @@ -543,11 +558,35 @@ impl Client { async fn handle_tune_command(&self, tunes: &TunesCommand) { let mut state = self.state.write().await; - state.heartbeat = self.max_value(self.opts.heartbeat, tunes.heartbeat); - state.max_frame_size = self.max_value(self.opts.max_frame_size, tunes.max_frame_size); + state.heartbeat = self.negotiate_value(self.opts.heartbeat, tunes.heartbeat); + state.max_frame_size = self.negotiate_value(self.opts.max_frame_size, tunes.max_frame_size); let heart_beat = state.heartbeat; let max_frame_size = state.max_frame_size; + + trace!( + "Handling tune with frame size {} and heartbeat {}", + max_frame_size, + heart_beat + ); + + if let Some(task) = state.heartbeat_task.take() { + task.abort(); + } + + if heart_beat != 0 { + let heartbeat_interval = (heart_beat / 2).max(1); + let channel = self.channel.clone(); + let heartbeat_task = tokio::spawn(async move { + loop { + trace!("Sending heartbeat"); + let _ = channel.send(HeartBeatCommand::default().into()).await; + tokio::time::sleep(Duration::from_secs(heartbeat_interval.into())).await; + } + }); + state.heartbeat_task = Some(heartbeat_task); + } + drop(state); let _ = self @@ -557,4 +596,10 @@ impl Client { self.tune_notifier.notify_one(); } + + async fn handle_heart_beat_command(&self) { + trace!("Received heartbeat"); + let mut state = self.state.write().await; + state.last_heatbeat = Instant::now(); + } } diff --git a/src/environment.rs b/src/environment.rs index 20a6bae..3e140a7 100644 --- a/src/environment.rs +++ b/src/environment.rs @@ -60,7 +60,9 @@ impl Environment { /// Delete a stream pub async fn delete_stream(&self, stream: &str) -> Result<(), StreamDeleteError> { - let response = self.create_client().await?.delete_stream(stream).await?; + let client = self.create_client().await?; + let response = client.delete_stream(stream).await?; + client.close().await?; if response.is_ok() { Ok(()) @@ -122,6 +124,10 @@ impl EnvironmentBuilder { self } + pub fn heartbeat(mut self, heartbeat: u32) -> EnvironmentBuilder { + self.0.client_options.heartbeat = heartbeat; + self + } pub fn metrics_collector( mut self, collector: impl MetricsCollector + Send + Sync + 'static, diff --git a/src/producer.rs b/src/producer.rs index f49b907..840daba 100644 --- a/src/producer.rs +++ b/src/producer.rs @@ -125,6 +125,7 @@ impl ProducerBuilder { metadata.leader, stream ); + client.close().await?; client = Client::connect(ClientOptions { host: metadata.leader.host.clone(), port: metadata.leader.port as u16, @@ -553,7 +554,10 @@ impl MessageHandler for ProducerConfirmHandler { trace!(?error); // TODO clean all waiting for confirm } - None => todo!(), + None => { + trace!("Connection closed"); + // TODO connection close clean all waiting + } } Ok(()) }