diff --git a/Cargo.lock b/Cargo.lock index cda8419a6..8af4ac3c9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -847,9 +847,9 @@ checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" [[package]] name = "h2" -version = "0.4.2" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "31d030e59af851932b72ceebadf4a2b5986dba4c3b99dd2493f8273a0f151943" +checksum = "816ec7294445779408f36fe57bc5b7fc1cf59664059096c65f905c1c61f58069" dependencies = [ "bytes", "fnv", diff --git a/rumqttc/CHANGELOG.md b/rumqttc/CHANGELOG.md index 254bd31a4..03f320ea4 100644 --- a/rumqttc/CHANGELOG.md +++ b/rumqttc/CHANGELOG.md @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * `size()` method on `Packet` calculates size once serialized. * `read()` and `write()` methods on `Packet`. * `ConnectionAborted` variant on `StateError` type to denote abrupt end to a connection +* `set_session_expiry_interval` and `session_expiry_interval` methods on `MqttOptions`. ### Changed @@ -27,6 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * Validate filters while creating subscription requests. * Make v4::Connect::write return correct value +* Resume session only if broker sends `CONNACK` with `session_present == 1`. ### Security @@ -56,7 +58,7 @@ To update your code simply remove `Key::ECC()` or `Key::RSA()` from the initiali `rusttls-pemfile` to `2.0.0`, `async-tungstenite` to `0.24.0`, `ws_stream_tungstenite` to `0.12.0` and `http` to `1.0.0`. This is a breaking change as types from some of these crates are part of the public API. -- `publish` / `subscribe` / `unsubscribe` methods on `AsyncClient` and `Client` now return a `PkidPromise` which resolves into the identifier value chosen by the `EventLoop` when handling the packet. +- `publish` / `subscribe` / `unsubscribe` methods on `AsyncClient` and `Client` now return a `NoticeFuture` which is noticed after the packet is released (sent in QoS0, ACKed in QoS1, COMPed in QoS2). ### Deprecated diff --git a/rumqttc/Cargo.toml b/rumqttc/Cargo.toml index 149e49fd0..bcf0d7752 100644 --- a/rumqttc/Cargo.toml +++ b/rumqttc/Cargo.toml @@ -30,6 +30,7 @@ bytes = "1.5" log = "0.4" flume = { version = "0.11", default-features = false, features = ["async"] } thiserror = "1" +linked-hash-map = "0.5" # Optional # rustls diff --git a/rumqttc/examples/ack_notif.rs b/rumqttc/examples/ack_notif.rs index bb98599a4..970394f69 100644 --- a/rumqttc/examples/ack_notif.rs +++ b/rumqttc/examples/ack_notif.rs @@ -1,4 +1,4 @@ -use tokio::task::{self, JoinSet}; +use tokio::{task::{self, JoinSet}, time}; use rumqttc::{AsyncClient, MqttOptions, QoS}; use std::error::Error; @@ -35,24 +35,38 @@ async fn main() -> Result<(), Box> { .wait_async() .await .unwrap(); + client + .subscribe("hello/world", QoS::AtLeastOnce) + .await + .unwrap() + .wait_async() + .await + .unwrap(); + client + .subscribe("hello/world", QoS::ExactlyOnce) + .await + .unwrap() + .wait_async() + .await + .unwrap(); // Publish and spawn wait for notification let mut set = JoinSet::new(); let future = client - .publish("hello/world", QoS::AtMostOnce, false, vec![1; 1024]) + .publish("hello/world", QoS::AtMostOnce, false, vec![1; 1]) .await .unwrap(); set.spawn(future.wait_async()); let future = client - .publish("hello/world", QoS::AtLeastOnce, false, vec![1; 1024]) + .publish("hello/world", QoS::AtLeastOnce, false, vec![1; 2]) .await .unwrap(); set.spawn(future.wait_async()); let future = client - .publish("hello/world", QoS::ExactlyOnce, false, vec![1; 1024]) + .publish("hello/world", QoS::ExactlyOnce, false, vec![1; 3]) .await .unwrap(); set.spawn(future.wait_async()); @@ -61,5 +75,6 @@ async fn main() -> Result<(), Box> { println!("Acknoledged = {:?}", res?); } + time::sleep(Duration::from_secs(6)).await; Ok(()) } diff --git a/rumqttc/examples/ack_notif_v5.rs b/rumqttc/examples/ack_notif_v5.rs new file mode 100644 index 000000000..78e99ef8f --- /dev/null +++ b/rumqttc/examples/ack_notif_v5.rs @@ -0,0 +1,80 @@ +use tokio::{task::{self, JoinSet}, time}; + +use rumqttc::v5::{AsyncClient, MqttOptions, mqttbytes::QoS}; +use std::error::Error; +use std::time::Duration; + +#[tokio::main(flavor = "current_thread")] +async fn main() -> Result<(), Box> { + pretty_env_logger::init(); + // color_backtrace::install(); + + let mut mqttoptions = MqttOptions::new("test-1", "localhost", 1883); + mqttoptions.set_keep_alive(Duration::from_secs(5)); + + let (client, mut eventloop) = AsyncClient::new(mqttoptions, 10); + task::spawn(async move { + loop { + let event = eventloop.poll().await; + match &event { + Ok(v) => { + println!("Event = {v:?}"); + } + Err(e) => { + println!("Error = {e:?}"); + } + } + } + }); + + // Subscribe and wait for broker acknowledgement + client + .subscribe("hello/world", QoS::AtMostOnce) + .await + .unwrap() + .wait_async() + .await + .unwrap(); + client + .subscribe("hello/world", QoS::AtLeastOnce) + .await + .unwrap() + .wait_async() + .await + .unwrap(); + client + .subscribe("hello/world", QoS::ExactlyOnce) + .await + .unwrap() + .wait_async() + .await + .unwrap(); + + // Publish and spawn wait for notification + let mut set = JoinSet::new(); + + let future = client + .publish("hello/world", QoS::AtMostOnce, false, vec![1; 1]) + .await + .unwrap(); + set.spawn(future.wait_async()); + + let future = client + .publish("hello/world", QoS::AtLeastOnce, false, vec![1; 2]) + .await + .unwrap(); + set.spawn(future.wait_async()); + + let future = client + .publish("hello/world", QoS::ExactlyOnce, false, vec![1; 3]) + .await + .unwrap(); + set.spawn(future.wait_async()); + + while let Some(res) = set.join_next().await { + println!("Acknoledged = {:?}", res?); + } + + time::sleep(Duration::from_secs(6)).await; + Ok(()) +} diff --git a/rumqttc/examples/async_manual_acks_v5.rs b/rumqttc/examples/async_manual_acks_v5.rs index bcf1bf356..7597a6164 100644 --- a/rumqttc/examples/async_manual_acks_v5.rs +++ b/rumqttc/examples/async_manual_acks_v5.rs @@ -10,6 +10,7 @@ fn create_conn() -> (AsyncClient, EventLoop) { let mut mqttoptions = MqttOptions::new("test-1", "localhost", 1884); mqttoptions .set_keep_alive(Duration::from_secs(5)) + .set_session_expiry_interval(u32::MAX.into()) .set_manual_acks(true) .set_clean_start(false); diff --git a/rumqttc/examples/pkid_promise.rs b/rumqttc/examples/pkid_promise.rs deleted file mode 100644 index 0fefd093e..000000000 --- a/rumqttc/examples/pkid_promise.rs +++ /dev/null @@ -1,70 +0,0 @@ -use futures_util::stream::StreamExt; -use tokio::{ - select, - task::{self, JoinSet}, -}; -use tokio_util::time::DelayQueue; - -use rumqttc::{AsyncClient, MqttOptions, QoS}; -use std::error::Error; -use std::time::Duration; - -#[tokio::main(flavor = "current_thread")] -async fn main() -> Result<(), Box> { - pretty_env_logger::init(); - // color_backtrace::install(); - - let mut mqttoptions = MqttOptions::new("test-1", "broker.emqx.io", 1883); - mqttoptions.set_keep_alive(Duration::from_secs(5)); - - let (client, mut eventloop) = AsyncClient::new(mqttoptions, 10); - task::spawn(async move { - requests(client).await; - }); - - loop { - let event = eventloop.poll().await; - match &event { - Ok(v) => { - println!("Event = {v:?}"); - } - Err(e) => { - println!("Error = {e:?}"); - return Ok(()); - } - } - } -} - -async fn requests(client: AsyncClient) { - let mut joins = JoinSet::new(); - joins.spawn( - client - .subscribe("hello/world", QoS::AtMostOnce) - .await - .unwrap() - .wait_async(), - ); - - let mut queue = DelayQueue::new(); - for i in 1..=10 { - queue.insert(i as usize, Duration::from_secs(i)); - } - - loop { - select! { - Some(i) = queue.next() => { - joins.spawn( - client - .publish("hello/world", QoS::ExactlyOnce, false, vec![1; i.into_inner()]) - .await - .unwrap().wait_async(), - ); - } - Some(Ok(Ok(pkid))) = joins.join_next() => { - println!("Pkid: {:?}", pkid); - } - else => break, - } - } -} diff --git a/rumqttc/examples/pkid_promise_v5.rs b/rumqttc/examples/pkid_promise_v5.rs deleted file mode 100644 index e4bd3d31c..000000000 --- a/rumqttc/examples/pkid_promise_v5.rs +++ /dev/null @@ -1,70 +0,0 @@ -use futures_util::stream::StreamExt; -use tokio::{ - select, - task::{self, JoinSet}, -}; -use tokio_util::time::DelayQueue; - -use rumqttc::v5::{mqttbytes::QoS, AsyncClient, MqttOptions}; -use std::error::Error; -use std::time::Duration; - -#[tokio::main(flavor = "current_thread")] -async fn main() -> Result<(), Box> { - pretty_env_logger::init(); - // color_backtrace::install(); - - let mut mqttoptions = MqttOptions::new("test-1", "broker.emqx.io", 1883); - mqttoptions.set_keep_alive(Duration::from_secs(5)); - - let (client, mut eventloop) = AsyncClient::new(mqttoptions, 10); - task::spawn(async move { - requests(client).await; - }); - - loop { - let event = eventloop.poll().await; - match &event { - Ok(v) => { - println!("Event = {v:?}"); - } - Err(e) => { - println!("Error = {e:?}"); - return Ok(()); - } - } - } -} - -async fn requests(client: AsyncClient) { - let mut joins = JoinSet::new(); - joins.spawn( - client - .subscribe("hello/world", QoS::AtMostOnce) - .await - .unwrap() - .wait_async(), - ); - - let mut queue = DelayQueue::new(); - for i in 1..=10 { - queue.insert(i as usize, Duration::from_secs(i)); - } - - loop { - select! { - Some(i) = queue.next() => { - joins.spawn( - client - .publish("hello/world", QoS::ExactlyOnce, false, vec![1; i.into_inner()]) - .await - .unwrap().wait_async(), - ); - } - Some(Ok(Ok(pkid))) = joins.join_next() => { - println!("Pkid: {:?}", pkid); - } - else => break, - } - } -} diff --git a/rumqttc/src/eventloop.rs b/rumqttc/src/eventloop.rs index 5687c6d71..237cfbc28 100644 --- a/rumqttc/src/eventloop.rs +++ b/rumqttc/src/eventloop.rs @@ -2,6 +2,7 @@ use crate::notice::NoticeTx; use crate::{framed::Network, Transport}; use crate::{Incoming, MqttState, NetworkOptions, Packet, Request, StateError}; use crate::{MqttOptions, Outgoing}; +use crate::NoticeError; use crate::framed::AsyncReadWrite; use crate::mqttbytes::v4::*; @@ -150,18 +151,29 @@ impl EventLoop { Ok(inner) => inner?, Err(_) => return Err(ConnectionError::NetworkTimeout), }; + // Last session might contain packets which aren't acked. If it's a new session, clear the pending packets. + if !connack.session_present { + for (tx, request) in self.pending.drain(..) { + // If the request is a publish request, send an error to the future that is waiting for the ack. + if let Request::Publish(_) = request { + tx.error(NoticeError::SessionReset) + } + } + } self.network = Some(network); if self.keepalive_timeout.is_none() && !self.mqtt_options.keep_alive.is_zero() { self.keepalive_timeout = Some(Box::pin(time::sleep(self.mqtt_options.keep_alive))); } - return Ok(Event::Incoming(connack)); + return Ok(Event::Incoming(Packet::ConnAck(connack))); } match self.select().await { Ok(v) => Ok(v), Err(e) => { + // MQTT requires that packets pending acknowledgement should be republished on session resume. + // Move pending messages from state to eventloop. self.clean(); Err(e) } @@ -296,14 +308,14 @@ impl EventLoop { async fn connect( mqtt_options: &MqttOptions, network_options: NetworkOptions, -) -> Result<(Network, Incoming), ConnectionError> { +) -> Result<(Network, ConnAck), ConnectionError> { // connect to the broker let mut network = network_connect(mqtt_options, network_options).await?; // make MQTT connection request (which internally awaits for ack) - let packet = mqtt_connect(mqtt_options, &mut network).await?; + let connack = mqtt_connect(mqtt_options, &mut network).await?; - Ok((network, packet)) + Ok((network, connack)) } pub(crate) async fn socket_connect( @@ -471,7 +483,7 @@ async fn network_connect( async fn mqtt_connect( options: &MqttOptions, network: &mut Network, -) -> Result { +) -> Result { let keep_alive = options.keep_alive().as_secs() as u16; let clean_session = options.clean_session(); let last_will = options.last_will(); @@ -488,7 +500,7 @@ async fn mqtt_connect( // validate connack match network.read().await? { Incoming::ConnAck(connack) if connack.code == ConnectReturnCode::Success => { - Ok(Packet::ConnAck(connack)) + Ok(connack) } Incoming::ConnAck(connack) => Err(ConnectionError::ConnectionRefused(connack.code)), packet => Err(ConnectionError::NotConnAck(packet)), diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index e547a3aee..017ec18e8 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -141,8 +141,7 @@ pub use client::{ pub use eventloop::{ConnectionError, Event, EventLoop}; pub use mqttbytes::v4::*; pub use mqttbytes::*; -use notice::NoticeTx; -pub use notice::{NoticeError, NoticeFuture}; +pub use notice::{NoticeTx, NoticeError, NoticeFuture}; #[cfg(feature = "use-rustls")] use rustls_native_certs::load_native_certs; pub use state::{MqttState, StateError}; diff --git a/rumqttc/src/notice.rs b/rumqttc/src/notice.rs index 3896e5c4d..cc7ee1e39 100644 --- a/rumqttc/src/notice.rs +++ b/rumqttc/src/notice.rs @@ -1,7 +1,8 @@ use tokio::sync::oneshot; use crate::{ - v5::mqttbytes::v5::{SubscribeReasonCode as V5SubscribeReasonCode, UnsubAckReason}, + v5::mqttbytes::v5::{SubscribeReasonCode as V5SubscribeReasonCode, UnsubAckReason, + PubAckReason, PubRecReason, PubCompReason}, SubscribeReasonCode, }; @@ -15,6 +16,14 @@ pub enum NoticeError { V5Subscribe(V5SubscribeReasonCode), #[error(" v5 Unsubscription Failure Reason: {0:?}")] V5Unsubscribe(UnsubAckReason), + #[error(" v5 Publish Ack Failure Reason Code: {0:?}")] + V5PubAck(PubAckReason), + #[error(" v5 Publish Rec Failure Reason Code: {0:?}")] + V5PubRec(PubRecReason), + #[error(" v5 Publish Comp Failure Reason Code: {0:?}")] + V5PubComp(PubCompReason), + #[error(" Dropped due to session reconnect with previous state expire/lost")] + SessionReset, } impl From for NoticeError { diff --git a/rumqttc/src/state.rs b/rumqttc/src/state.rs index 90ff77acf..220fa9a18 100644 --- a/rumqttc/src/state.rs +++ b/rumqttc/src/state.rs @@ -3,7 +3,7 @@ use crate::{Event, Incoming, NoticeError, Outgoing, Request}; use crate::mqttbytes::v4::*; use crate::mqttbytes::{self, *}; -use std::collections::{HashMap, VecDeque}; +use std::collections::VecDeque; use std::{io, time::Instant}; /// Errors during state handling @@ -63,14 +63,16 @@ pub struct MqttState { /// Maximum number of allowed inflight pub(crate) max_inflight: u16, /// Outgoing QoS 1, 2 publishes which aren't acked yet - pub(crate) outgoing_pub: HashMap, + pub(crate) outgoing_pub1: VecDeque<(Publish, NoticeTx)>, + pub(crate) outgoing_pub2: VecDeque<(Publish, NoticeTx)>, /// Packet ids of released QoS 2 publishes - pub(crate) outgoing_rel: HashMap, + pub(crate) outgoing_rel: VecDeque<(u16, NoticeTx)>, /// Packet ids on incoming QoS 2 publishes pub(crate) incoming_pub: Vec>, - - outgoing_sub: HashMap, - outgoing_unsub: HashMap, + /// Outgoing subscribes + pub(crate) outgoing_sub: VecDeque<(u16, NoticeTx)>, + /// Outgoing unsubscribes + pub(crate) outgoing_unsub: VecDeque<(u16, NoticeTx)>, /// Last collision due to broker not acking in order pub(crate) collision: Option<(Publish, NoticeTx)>, @@ -95,11 +97,12 @@ impl MqttState { inflight: 0, max_inflight, // index 0 is wasted as 0 is not a valid packet id - outgoing_pub: HashMap::new(), - outgoing_rel: HashMap::new(), - incoming_pub: vec![None; std::u16::MAX as usize + 1], - outgoing_sub: HashMap::new(), - outgoing_unsub: HashMap::new(), + outgoing_pub1: VecDeque::with_capacity(10), + outgoing_pub2: VecDeque::with_capacity(10), + outgoing_rel: VecDeque::with_capacity(10), + incoming_pub: vec![None; u16::MAX as usize + 1], + outgoing_sub: VecDeque::with_capacity(10), + outgoing_unsub: VecDeque::with_capacity(10), collision: None, // TODO: Optimize these sizes later events: VecDeque::with_capacity(100), @@ -111,13 +114,27 @@ impl MqttState { pub fn clean(&mut self) -> Vec<(NoticeTx, Request)> { let mut pending = Vec::with_capacity(100); - for (_, (publish, tx)) in self.outgoing_pub.drain() { - let request = Request::Publish(publish); - pending.push((tx, request)); + for outgoing_pub in [&mut self.outgoing_pub1, &mut self.outgoing_pub2] { + let mut second_half = Vec::with_capacity(100); + let mut last_pkid_found = false; + for (publish, tx) in outgoing_pub.drain(..) { + let this_pkid = publish.pkid; + let request = Request::Publish(publish); + if !last_pkid_found { + second_half.push((tx, request)); + if this_pkid == self.last_puback { + last_pkid_found = true; + } + } else { + pending.push((tx, request)); + } + } + pending.extend(second_half); } + println!("pending = {:?}", pending.len()); // remove and collect pending releases - for (pkid, tx) in self.outgoing_rel.drain() { + for (pkid, tx) in self.outgoing_rel.drain(..) { let request = Request::PubRel(PubRel::new(pkid)); pending.push((tx, request)); } @@ -193,11 +210,14 @@ impl MqttState { error!("Unsolicited suback packet: {:?}", suback.pkid); return Err(StateError::Unsolicited(suback.pkid)); } - - let tx = self + // No expecting ordered acks for suback + // Search outgoing_sub to find the right suback.pkid + let pos = self .outgoing_sub - .remove(&suback.pkid) + .iter() + .position(|(pkid, _)| *pkid == suback.pkid) .ok_or(StateError::Unsolicited(suback.pkid))?; + let (_, tx) = self.outgoing_sub.remove(pos).unwrap(); for reason in suback.return_codes.iter() { match reason { @@ -223,10 +243,15 @@ impl MqttState { error!("Unsolicited unsuback packet: {:?}", unsuback.pkid); return Err(StateError::Unsolicited(unsuback.pkid)); } - self.outgoing_sub - .remove(&unsuback.pkid) - .ok_or(StateError::Unsolicited(unsuback.pkid))? - .success(); + // No expecting ordered acks for unsuback + // Search outgoing_unsub to find the right suback.pkid + let pos = self + .outgoing_unsub + .iter() + .position(|(pkid, _)| *pkid == unsuback.pkid) + .ok_or(StateError::Unsolicited(unsuback.pkid))?; + let (_, tx) = self.outgoing_unsub.remove(pos).unwrap(); + tx.success(); Ok(None) } @@ -263,10 +288,11 @@ impl MqttState { error!("Unsolicited puback packet: {:?}", puback.pkid); return Err(StateError::Unsolicited(puback.pkid)); } - + // Expecting ordered acks for puback + // Check front of outgoing_pub to see if it's in order let (_, tx) = self - .outgoing_pub - .remove(&puback.pkid) + .outgoing_pub1 + .pop_front() .ok_or(StateError::Unsolicited(puback.pkid))?; tx.success(); @@ -274,8 +300,11 @@ impl MqttState { self.inflight -= 1; let packet = self.check_collision(puback.pkid).map(|(publish, tx)| { - self.outgoing_pub - .insert(publish.pkid, (publish.clone(), tx)); + if publish.qos == QoS::AtLeastOnce { + self.outgoing_pub1.push_back((publish.clone(), tx)); + } else if publish.qos == QoS::AtMostOnce { + self.outgoing_pub2.push_back((publish.clone(), tx)); + } self.inflight += 1; let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); @@ -294,13 +323,19 @@ impl MqttState { return Err(StateError::Unsolicited(pubrec.pkid)); } - let (_, tx) = self - .outgoing_pub - .remove(&pubrec.pkid) + // Expecting ordered acks for pubrec + // Check front of outgoing_pub to see if it's in order + let (publish, tx) = self + .outgoing_pub2 + .pop_front() .ok_or(StateError::Unsolicited(pubrec.pkid))?; + if publish.pkid != pubrec.pkid { + error!("Unsolicited pubrec packet: {:?}", pubrec.pkid); + return Err(StateError::Unsolicited(pubrec.pkid)); + } // NOTE: Inflight - 1 for qos2 in comp - self.outgoing_rel.insert(pubrec.pkid, tx); + self.outgoing_rel.push_back((pubrec.pkid, tx)); let pubrel = PubRel { pkid: pubrec.pkid }; let event = Event::Outgoing(Outgoing::PubRel(pubrec.pkid)); self.events.push_back(event); @@ -331,15 +366,25 @@ impl MqttState { error!("Unsolicited pubcomp packet: {:?}", pubcomp.pkid); return Err(StateError::Unsolicited(pubcomp.pkid)); } - self.outgoing_rel - .remove(&pubcomp.pkid) - .ok_or(StateError::Unsolicited(pubcomp.pkid))? - .success(); + // Expecting ordered acks for pubcomp + // Check front of outgoing_pub to see if it's in order + let (pkid, tx) = self + .outgoing_rel + .pop_front() + .ok_or(StateError::Unsolicited(pubcomp.pkid))?; + if pkid != pubcomp.pkid { + error!("Unsolicited pubcomp packet: {:?}", pubcomp.pkid); + return Err(StateError::Unsolicited(pubcomp.pkid)); + } + tx.success(); self.inflight -= 1; let packet = self.check_collision(pubcomp.pkid).map(|(publish, tx)| { - self.outgoing_pub - .insert(pubcomp.pkid, (publish.clone(), tx)); + if publish.qos == QoS::AtLeastOnce { + self.outgoing_pub1.push_back((publish.clone(), tx)); + } else if publish.qos == QoS::AtMostOnce { + self.outgoing_pub2.push_back((publish.clone(), tx)); + } let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); self.events.push_back(event); self.collision_ping_count = 0; @@ -370,7 +415,13 @@ impl MqttState { let pkid = publish.pkid; - if self.outgoing_pub.get(&publish.pkid).is_some() { + let outgoing_pub = match publish.qos { + QoS::AtLeastOnce => &mut self.outgoing_pub1, + QoS::ExactlyOnce => &mut self.outgoing_pub2, + _ => unreachable!(), + }; + if let Some(pos) = outgoing_pub.iter().position(|(publish, _)| publish.pkid == pkid) { + outgoing_pub.get(pos); info!("Collision on packet id = {:?}", publish.pkid); self.collision = Some((publish, notice_tx)); let event = Event::Outgoing(Outgoing::AwaitAck(pkid)); @@ -380,7 +431,7 @@ impl MqttState { // if there is an existing publish at this pkid, this implies that broker hasn't acked this // packet yet. This error is possible only when broker isn't acking sequentially - self.outgoing_pub.insert(pkid, (publish.clone(), notice_tx)); + outgoing_pub.push_back((publish.clone(), notice_tx)); self.inflight += 1; } else { notice_tx.success() @@ -479,7 +530,7 @@ impl MqttState { subscription.filters, subscription.pkid ); - self.outgoing_sub.insert(pkid, notice_tx); + self.outgoing_sub.push_back((pkid, notice_tx)); let event = Event::Outgoing(Outgoing::Subscribe(subscription.pkid)); self.events.push_back(event); @@ -499,7 +550,7 @@ impl MqttState { unsub.topics, unsub.pkid ); - self.outgoing_unsub.insert(pkid, notice_tx); + self.outgoing_unsub.push_back((pkid, notice_tx)); let event = Event::Outgoing(Outgoing::Unsubscribe(unsub.pkid)); self.events.push_back(event); @@ -539,8 +590,7 @@ impl MqttState { _ => pubrel, }; - self.outgoing_rel.insert(pubrel.pkid, notice_tx); - self.inflight += 1; + self.outgoing_rel.push_back((pubrel.pkid, notice_tx)); Ok(pubrel) } @@ -566,7 +616,7 @@ impl MqttState { #[cfg(test)] mod test { - use std::collections::HashMap; + use std::collections::VecDeque; use super::{MqttState, StateError}; use crate::mqttbytes::v4::*; @@ -750,11 +800,18 @@ mod test { mqtt.handle_incoming_puback(&PubAck::new(1)).unwrap(); assert_eq!(mqtt.inflight, 1); - mqtt.handle_incoming_puback(&PubAck::new(2)).unwrap(); + mqtt.handle_incoming_pubrec(&PubRec::new(2)).unwrap(); + assert_eq!(mqtt.inflight, 1); + + let (tx, _) = NoticeTx::new(); + mqtt.outgoing_pubrel(PubRel::new(2), tx).unwrap(); + assert_eq!(mqtt.inflight, 1); + + mqtt.handle_incoming_pubcomp(&PubComp::new(2)).unwrap(); assert_eq!(mqtt.inflight, 0); - assert!(mqtt.outgoing_pub.get(&1).is_none()); - assert!(mqtt.outgoing_pub.get(&2).is_none()); + assert!(mqtt.outgoing_pub1.get(0).is_none()); + assert!(mqtt.outgoing_pub2.get(0).is_none()); } #[test] @@ -785,11 +842,11 @@ mod test { assert_eq!(mqtt.inflight, 2); // check if the remaining element's pkid is 1 - let (backup, _) = mqtt.outgoing_pub.get(&1).unwrap(); + let (backup, _) = mqtt.outgoing_pub1.get(0).unwrap(); assert_eq!(backup.pkid, 1); // check if the qos2 element's release pkik has been set - assert!(mqtt.outgoing_rel.get(&2).is_some()); + assert!(mqtt.outgoing_rel.get(0).is_some()); } #[test] @@ -885,11 +942,11 @@ mod test { fn clean_is_calculating_pending_correctly() { let mut mqtt = build_mqttstate(); - fn build_outgoing_pub() -> HashMap { - let mut outgoing_pub = HashMap::new(); + fn build_outgoing_pub() -> VecDeque<(Publish, NoticeTx)> { + let mut outgoing_pub = VecDeque::new(); let (tx, _) = NoticeTx::new(); - outgoing_pub.insert( - 2, + outgoing_pub.push_back( + // 2, ( Publish { dup: false, @@ -903,8 +960,8 @@ mod test { ), ); let (tx, _) = NoticeTx::new(); - outgoing_pub.insert( - 3, + outgoing_pub.push_back( + // 3, ( Publish { dup: false, @@ -918,8 +975,8 @@ mod test { ), ); let (tx, _) = NoticeTx::new(); - outgoing_pub.insert( - 4, + outgoing_pub.push_back( + // 4, ( Publish { dup: false, @@ -933,8 +990,8 @@ mod test { ), ); let (tx, _) = NoticeTx::new(); - outgoing_pub.insert( - 7, + outgoing_pub.push_back( + // 7, ( Publish { dup: false, @@ -951,7 +1008,7 @@ mod test { outgoing_pub } - mqtt.outgoing_pub = build_outgoing_pub(); + mqtt.outgoing_pub1 = build_outgoing_pub(); mqtt.last_puback = 3; let requests = mqtt.clean(); let res = vec![6, 1, 2, 3]; @@ -963,7 +1020,7 @@ mod test { } } - mqtt.outgoing_pub = build_outgoing_pub(); + mqtt.outgoing_pub1 = build_outgoing_pub(); mqtt.last_puback = 0; let requests = mqtt.clean(); let res = vec![1, 2, 3, 6]; @@ -975,7 +1032,7 @@ mod test { } } - mqtt.outgoing_pub = build_outgoing_pub(); + mqtt.outgoing_pub1 = build_outgoing_pub(); mqtt.last_puback = 6; let requests = mqtt.clean(); let res = vec![1, 2, 3, 6]; diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index 2c24edbfb..0db89296d 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -3,7 +3,7 @@ use super::mqttbytes::v5::*; use super::{Incoming, MqttOptions, MqttState, Outgoing, Request, StateError, Transport}; use crate::eventloop::socket_connect; use crate::framed::AsyncReadWrite; -use crate::notice::NoticeTx; +use crate::notice::{NoticeTx, NoticeError}; use flume::{bounded, Receiver, Sender}; use tokio::select; @@ -142,18 +142,29 @@ impl EventLoop { connect(&mut self.options), ) .await??; + // Last session might contain packets which aren't acked. If it's a new session, clear the pending packets. + if !connack.session_present { + for (tx, request) in self.pending.drain(..) { + // If the request is a publish request, send an error to the future that is waiting for the ack. + if let Request::Publish(_) = request { + tx.error(NoticeError::SessionReset) + } + } + } self.network = Some(network); if self.keepalive_timeout.is_none() { self.keepalive_timeout = Some(Box::pin(time::sleep(self.options.keep_alive))); } - self.state.handle_incoming_packet(connack)?; + self.state.handle_incoming_packet(Packet::ConnAck(connack))?; } match self.select().await { Ok(v) => Ok(v), Err(e) => { + // MQTT requires that packets pending acknowledgement should be republished on session resume. + // Move pending messages from state to eventloop. self.clean(); Err(e) } @@ -210,7 +221,9 @@ impl EventLoop { self.options.pending_throttle ), if !self.pending.is_empty() || (!inflight_full && !collision) => match o { Ok((tx, request)) => { - self.state.handle_outgoing_packet(tx, request)?; + if let Some(outgoing) = self.state.handle_outgoing_packet(tx, request)? { + network.write(outgoing).await?; + } network.flush().await?; Ok(self.state.events.pop_front().unwrap()) } @@ -230,7 +243,9 @@ impl EventLoop { timeout.as_mut().reset(Instant::now() + self.options.keep_alive); let (tx, _) = NoticeTx::new(); - self.state.handle_outgoing_packet(tx, Request::PingReq)?; + if let Some(outgoing) = self.state.handle_outgoing_packet(tx, Request::PingReq)? { + network.write(outgoing).await?; + } network.flush().await?; Ok(self.state.events.pop_front().unwrap()) } @@ -261,19 +276,19 @@ impl EventLoop { /// the stream. /// This function (for convenience) includes internal delays for users to perform internal sleeps /// between re-connections so that cancel semantics can be used during this sleep -async fn connect(options: &mut MqttOptions) -> Result<(Network, Incoming), ConnectionError> { +async fn connect(options: &mut MqttOptions) -> Result<(Network, ConnAck), ConnectionError> { // connect to the broker let mut network = network_connect(options).await?; // make MQTT connection request (which internally awaits for ack) - let packet = mqtt_connect(options, &mut network).await?; + let connack = mqtt_connect(options, &mut network).await?; // Last session might contain packets which aren't acked. MQTT says these packets should be // republished in the next session // move pending messages from state to eventloop // let pending = self.state.clean(); // self.pending = pending.into_iter(); - Ok((network, packet)) + Ok((network, connack)) } async fn network_connect(options: &MqttOptions) -> Result { @@ -385,7 +400,7 @@ async fn network_connect(options: &MqttOptions) -> Result Result { +) -> Result { let keep_alive = options.keep_alive().as_secs() as u16; let clean_start = options.clean_start(); let client_id = options.client_id(); @@ -404,14 +419,18 @@ async fn mqtt_connect( // validate connack match network.read().await? { Incoming::ConnAck(connack) if connack.code == ConnectReturnCode::Success => { - // Override local keep_alive value if set by server. if let Some(props) = &connack.properties { if let Some(keep_alive) = props.server_keep_alive { options.keep_alive = Duration::from_secs(keep_alive as u64); } network.set_max_outgoing_size(props.max_packet_size); + + // Override local session_expiry_interval value if set by server. + if props.session_expiry_interval.is_some() { + options.set_session_expiry_interval(props.session_expiry_interval); + } } - Ok(Packet::ConnAck(connack)) + Ok(connack) } Incoming::ConnAck(connack) => Err(ConnectionError::ConnectionRefused(connack.code)), packet => Err(ConnectionError::NotConnAck(Box::new(packet))), diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index 4184ca46b..00ad6df68 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -33,7 +33,7 @@ pub type Incoming = Packet; /// Requests by the client to mqtt event loop. Request are /// handled one by one. -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Debug)] pub enum Request { Publish(Publish), PubAck(PubAck), @@ -49,6 +49,47 @@ pub enum Request { Disconnect, } +impl Clone for Request { + fn clone(&self) -> Self { + match self { + Self::Publish(p) => Self::Publish(p.clone()), + Self::PubAck(p) => Self::PubAck(p.clone()), + Self::PubRec(p) => Self::PubRec(p.clone()), + Self::PubRel(p) => Self::PubRel(p.clone()), + Self::PubComp(p) => Self::PubComp(p.clone()), + Self::Subscribe(p) => Self::Subscribe(p.clone()), + Self::SubAck(p) => Self::SubAck(p.clone()), + Self::PingReq => Self::PingReq, + Self::PingResp => Self::PingResp, + Self::Disconnect => Self::Disconnect, + Self::Unsubscribe(p) => Self::Unsubscribe(p.clone()), + Self::UnsubAck(p) => Self::UnsubAck(p.clone()), + } + } +} + +impl PartialEq for Request { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Self::Publish(p1), Self::Publish(p2)) => p1 == p2, + (Self::PubAck(p1), Self::PubAck(p2)) => p1 == p2, + (Self::PubRec(p1), Self::PubRec(p2)) => p1 == p2, + (Self::PubRel(p1), Self::PubRel(p2)) => p1 == p2, + (Self::PubComp(p1), Self::PubComp(p2)) => p1 == p2, + (Self::Subscribe(p1), Self::Subscribe(p2)) => p1 == p2, + (Self::SubAck(p1), Self::SubAck(p2)) => p1 == p2, + (Self::PingReq, Self::PingReq) + | (Self::PingResp, Self::PingResp) + | (Self::Disconnect, Self::Disconnect) => true, + (Self::Unsubscribe(p1), Self::Unsubscribe(p2)) => p1 == p2, + (Self::UnsubAck(p1), Self::UnsubAck(p2)) => p1 == p2, + _ => false, + } + } +} + +impl Eq for Request {} + #[cfg(feature = "websocket")] type RequestModifierFn = Arc< dyn Fn(http::Request<()>) -> Pin> + Send>> @@ -312,6 +353,27 @@ impl MqttOptions { self.connect_properties.clone() } + /// set session expiry interval on connection properties + pub fn set_session_expiry_interval(&mut self, interval: Option) -> &mut Self { + if let Some(conn_props) = &mut self.connect_properties { + conn_props.session_expiry_interval = interval; + self + } else { + let mut conn_props = ConnectProperties::new(); + conn_props.session_expiry_interval = interval; + self.set_connect_properties(conn_props) + } + } + + /// get session expiry interval on connection properties + pub fn session_expiry_interval(&self) -> Option { + if let Some(conn_props) = &self.connect_properties { + conn_props.session_expiry_interval + } else { + None + } + } + /// set receive maximum on connection properties pub fn set_receive_maximum(&mut self, recv_max: Option) -> &mut Self { if let Some(conn_props) = &mut self.connect_properties { diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index e7a99da36..d11d5cd27 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -402,17 +402,18 @@ impl MqttState { .outgoing_pub .remove(&puback.pkid) .ok_or(StateError::Unsolicited(puback.pkid))?; - tx.success(); self.inflight -= 1; if puback.reason != PubAckReason::Success && puback.reason != PubAckReason::NoMatchingSubscribers { + tx.error(NoticeError::V5PubAck(puback.reason)); return Err(StateError::PubAckFail { reason: puback.reason, }); } + tx.success(); let packet = self.check_collision(puback.pkid).map(|(publish, tx)| { self.outgoing_pub @@ -443,6 +444,7 @@ impl MqttState { if pubrec.reason != PubRecReason::Success && pubrec.reason != PubRecReason::NoMatchingSubscribers { + tx.error(NoticeError::V5PubRec(pubrec.reason)); return Err(StateError::PubRecFail { reason: pubrec.reason, }); @@ -484,12 +486,23 @@ impl MqttState { error!("Unsolicited pubcomp packet: {:?}", pubcomp.pkid); return Err(StateError::Unsolicited(pubcomp.pkid)); } - self.outgoing_rel + + let tx = self + .outgoing_rel .remove(&pubcomp.pkid) - .ok_or(StateError::Unsolicited(pubcomp.pkid))? - .success(); + .ok_or(StateError::Unsolicited(pubcomp.pkid))?; self.inflight -= 1; + + if pubcomp.reason != PubCompReason::Success + { + tx.error(NoticeError::V5PubComp(pubcomp.reason)); + return Err(StateError::PubCompFail { + reason: pubcomp.reason, + }); + } + tx.success(); + let packet = self.check_collision(pubcomp.pkid).map(|(publish, tx)| { self.outgoing_pub .insert(pubcomp.pkid, (publish.clone(), tx)); diff --git a/rumqttc/tests/broker.rs b/rumqttc/tests/broker.rs index ea66448f5..760a2ab37 100644 --- a/rumqttc/tests/broker.rs +++ b/rumqttc/tests/broker.rs @@ -21,7 +21,7 @@ pub struct Broker { impl Broker { /// Create a new broker which accepts 1 mqtt connection - pub async fn new(port: u16, connack: u8) -> Broker { + pub async fn new(port: u16, connack: u8, session_saved: bool) -> Broker { let addr = format!("127.0.0.1:{port}"); let listener = TcpListener::bind(&addr).await.unwrap(); @@ -32,9 +32,12 @@ impl Broker { framed.readb(&mut incoming).await.unwrap(); match incoming.pop_front().unwrap() { - Packet::Connect(_) => { + Packet::Connect(connect) => { let connack = match connack { - 0 => ConnAck::new(ConnectReturnCode::Success, false), + 0 => ConnAck::new( + ConnectReturnCode::Success, + !connect.clean_session && session_saved, + ), 1 => ConnAck::new(ConnectReturnCode::BadUserNamePassword, false), _ => { return Broker { diff --git a/rumqttc/tests/reliability.rs b/rumqttc/tests/reliability.rs index 0a83d57ce..633ca4706 100644 --- a/rumqttc/tests/reliability.rs +++ b/rumqttc/tests/reliability.rs @@ -72,7 +72,7 @@ async fn _tick( #[tokio::test] async fn connection_should_timeout_on_time() { task::spawn(async move { - let _broker = Broker::new(1880, 3).await; + let _broker = Broker::new(1880, 3, false).await; time::sleep(Duration::from_secs(10)).await; }); @@ -125,7 +125,7 @@ async fn idle_connection_triggers_pings_on_time() { run(&mut eventloop, false).await.unwrap(); }); - let mut broker = Broker::new(1885, 0).await; + let mut broker = Broker::new(1885, 0, false).await; let mut count = 0; let mut start = Instant::now(); @@ -169,7 +169,7 @@ async fn some_outgoing_and_no_incoming_should_trigger_pings_on_time() { run(&mut eventloop, false).await.unwrap(); }); - let mut broker = Broker::new(1886, 0).await; + let mut broker = Broker::new(1886, 0, false).await; let mut count = 0; let mut start = Instant::now(); @@ -204,7 +204,7 @@ async fn some_incoming_and_no_outgoing_should_trigger_pings_on_time() { run(&mut eventloop, false).await.unwrap(); }); - let mut broker = Broker::new(2000, 0).await; + let mut broker = Broker::new(2000, 0, false).await; let mut count = 0; // Start sending qos 0 publishes to the client. This triggers @@ -238,7 +238,7 @@ async fn detects_halfopen_connections_in_the_second_ping_request() { // A broker which consumes packets but doesn't reply task::spawn(async move { - let mut broker = Broker::new(2001, 0).await; + let mut broker = Broker::new(2001, 0, false).await; broker.blackhole().await; }); @@ -279,7 +279,7 @@ async fn requests_are_blocked_after_max_inflight_queue_size() { run(&mut eventloop, false).await.unwrap(); }); - let mut broker = Broker::new(1887, 0).await; + let mut broker = Broker::new(1887, 0, false).await; for i in 1..=10 { let packet = broker.read_publish().await; @@ -306,7 +306,7 @@ async fn requests_are_recovered_after_inflight_queue_size_falls_below_max() { run(&mut eventloop, true).await.unwrap(); }); - let mut broker = Broker::new(1888, 0).await; + let mut broker = Broker::new(1888, 0, false).await; // packet 1, 2, and 3 assert!(broker.read_publish().await.is_some()); @@ -341,7 +341,7 @@ async fn packet_id_collisions_are_detected_and_flow_control_is_applied() { }); task::spawn(async move { - let mut broker = Broker::new(1891, 0).await; + let mut broker = Broker::new(1891, 0, false).await; // read all incoming packets first for i in 1..=4 { @@ -449,8 +449,8 @@ async fn next_poll_after_connect_failure_reconnects() { let options = MqttOptions::new("dummy", "127.0.0.1", 3000); task::spawn(async move { - let _broker = Broker::new(3000, 1).await; - let _broker = Broker::new(3000, 0).await; + let _broker = Broker::new(3000, 1, false).await; + let _broker = Broker::new(3000, 0, false).await; time::sleep(Duration::from_secs(15)).await; }); @@ -474,7 +474,9 @@ async fn next_poll_after_connect_failure_reconnects() { #[tokio::test] async fn reconnection_resumes_from_the_previous_state() { let mut options = MqttOptions::new("dummy", "127.0.0.1", 3001); - options.set_keep_alive(Duration::from_secs(5)); + options + .set_keep_alive(Duration::from_secs(5)) + .set_clean_session(false); // start sending qos0 publishes. Makes sure that there is out activity but no in activity let (client, mut eventloop) = AsyncClient::new(options, 5); @@ -489,7 +491,7 @@ async fn reconnection_resumes_from_the_previous_state() { }); // broker connection 1 - let mut broker = Broker::new(3001, 0).await; + let mut broker = Broker::new(3001, 0, false).await; for i in 1..=2 { let packet = broker.read_publish().await.unwrap(); assert_eq!(i, packet.payload[0]); @@ -503,7 +505,7 @@ async fn reconnection_resumes_from_the_previous_state() { // a block around broker with {} is closing the connection as expected // broker connection 2 - let mut broker = Broker::new(3001, 0).await; + let mut broker = Broker::new(3001, 0, true).await; for i in 3..=4 { let packet = broker.read_publish().await.unwrap(); assert_eq!(i, packet.payload[0]); @@ -514,7 +516,9 @@ async fn reconnection_resumes_from_the_previous_state() { #[tokio::test] async fn reconnection_resends_unacked_packets_from_the_previous_connection_first() { let mut options = MqttOptions::new("dummy", "127.0.0.1", 3002); - options.set_keep_alive(Duration::from_secs(5)); + options + .set_keep_alive(Duration::from_secs(5)) + .set_clean_session(false); // start sending qos0 publishes. this makes sure that there is // outgoing activity but no incoming activity @@ -530,14 +534,14 @@ async fn reconnection_resends_unacked_packets_from_the_previous_connection_first }); // broker connection 1. receive but don't ack - let mut broker = Broker::new(3002, 0).await; + let mut broker = Broker::new(3002, 0, false).await; for i in 1..=2 { let packet = broker.read_publish().await.unwrap(); assert_eq!(i, packet.payload[0]); } // broker connection 2 receives from scratch - let mut broker = Broker::new(3002, 0).await; + let mut broker = Broker::new(3002, 0, true).await; for i in 1..=6 { let packet = broker.read_publish().await.unwrap(); assert_eq!(i, packet.payload[0]); @@ -559,7 +563,7 @@ async fn state_is_being_cleaned_properly_and_pending_request_calculated_properly }); task::spawn(async move { - let mut broker = Broker::new(3004, 0).await; + let mut broker = Broker::new(3004, 0, false).await; while (broker.read_packet().await).is_some() { time::sleep(Duration::from_secs_f64(0.5)).await; }