diff --git a/Cargo.toml b/Cargo.toml index c7f9e18..2b690e3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,10 +19,10 @@ maintenance = { status = "actively-developed" } [dependencies] embedded-hal = { version = "1.0.0-alpha.4" } +embedded-nal = "0.2.0" nb = "^1" heapless = { version = "^0.5.5", features = ["serde"] } no-std-net = "0.4.0" -embedded-nal = { git = "https://github.com/BlackbirdHQ/embedded-nal", branch = "factbird-mini-1.0" } mqttrs = { version = "0.4.0", default-features = false } defmt = { version = "^0.1" } diff --git a/examples/common/network.rs b/examples/common/network.rs index d69ecf1..c1bf0cc 100644 --- a/examples/common/network.rs +++ b/examples/common/network.rs @@ -1,4 +1,4 @@ -use embedded_nal::{AddrType, Dns, Mode, SocketAddr, TcpStack}; +use embedded_nal::{AddrType, Dns, SocketAddr, TcpClient}; use heapless::{consts, String}; use no_std_net::IpAddr; use std::io::{ErrorKind, Read, Write}; @@ -10,12 +10,11 @@ pub struct Network; pub struct TcpSocket { pub stream: Option, - mode: Mode, } impl TcpSocket { - pub fn new(mode: Mode) -> Self { - TcpSocket { stream: None, mode } + pub fn new() -> Self { + TcpSocket { stream: None } } } @@ -33,24 +32,15 @@ impl Dns for Network { } } -impl TcpStack for Network { +impl TcpClient for Network { type Error = (); type TcpSocket = TcpSocket; - fn open(&self, mode: Mode) -> Result { - Ok(TcpSocket::new(mode)) + fn socket(&self) -> Result { + Ok(TcpSocket::new()) } - fn read_with(&self, network: &mut Self::TcpSocket, f: F) -> nb::Result - where - F: FnOnce(&[u8], Option<&[u8]>) -> usize, - { - let buf = &mut [0u8; 512]; - let len = self.read(network, buf)?; - Ok(f(&mut buf[..len], None)) - } - - fn read( + fn receive( &self, network: &mut Self::TcpSocket, buf: &mut [u8], @@ -65,7 +55,7 @@ impl TcpStack for Network { } } - fn write( + fn send( &self, network: &mut Self::TcpSocket, buf: &[u8], @@ -83,33 +73,12 @@ impl TcpStack for Network { fn connect( &self, - network: Self::TcpSocket, + network: &mut Self::TcpSocket, remote: SocketAddr, - ) -> Result { - Ok(match TcpStream::connect(format!("{}", remote)) { - Ok(stream) => { - match network.mode { - Mode::Blocking => { - stream.set_write_timeout(None).unwrap(); - stream.set_read_timeout(None).unwrap(); - } - Mode::NonBlocking => panic!("Nonblocking socket mode not supported!"), - Mode::Timeout(t) => { - stream - .set_write_timeout(Some(std::time::Duration::from_millis(t as u64))) - .unwrap(); - stream - .set_read_timeout(Some(std::time::Duration::from_millis(t as u64))) - .unwrap(); - } - }; - TcpSocket { - stream: Some(stream), - mode: network.mode, - } - } - Err(_e) => return Err(()), - }) + ) -> nb::Result<(), Self::Error> { + TcpStream::connect(format!("{}", remote)) + .map(|stream| drop(network.stream.replace(stream))) + .map_err(|_| ().into()) } fn close(&self, _network: Self::TcpSocket) -> Result<(), Self::Error> { diff --git a/examples/linux.rs b/examples/linux.rs index d20342c..f0b456a 100644 --- a/examples/linux.rs +++ b/examples/linux.rs @@ -1,8 +1,9 @@ mod common; +use embedded_nal::{AddrType, Dns, TcpClient}; use mqttrust::{ - MqttEvent, MqttOptions, Notification, PublishRequest, QoS, Request, SubscribeRequest, - SubscribeTopic, + EventLoop, MqttOptions, Notification, PublishRequest, QoS, Request, SubscribeRequest, + SubscribeTopic, TcpSession, }; use common::network::Network; @@ -17,21 +18,26 @@ fn main() { let (mut p, c) = unsafe { Q.split() }; let network = Network; - // let network = std_embedded_nal::STACK; + let mut socket = network.socket().unwrap(); // Connect to broker.hivemq.com:1883 - let mut mqtt_eventloop = MqttEvent::new( - c, - SysTimer::new(), - MqttOptions::new("mqtt_test_client_id", "broker.hivemq.com".into(), 1883), - ); + let broker_addr = network + .gethostbyname("broker.hivemq.com", AddrType::Either) + .unwrap(); + network + .connect(&mut socket, (broker_addr, 1883).into()) + .expect("TCP client cannot connect to the broker"); + let mut session = TcpSession::from(socket); + let mut mqtt_eventloop = + EventLoop::new(c, SysTimer::new(), MqttOptions::new("mqtt_test_client_id")); - nb::block!(mqtt_eventloop.connect(&network)).expect("Failed to connect to MQTT"); + nb::block!(mqtt_eventloop.connect(&network, &mut session)) + .expect("MQTT client's connection request failed"); thread::Builder::new() .name("eventloop".to_string()) .spawn(move || loop { - match nb::block!(mqtt_eventloop.yield_event(&network)) { + match nb::block!(mqtt_eventloop.yield_event(&network, &mut session)) { Ok(Notification::Publish(_publish)) => { // defmt::debug!( // "[{}, {:?}]: {:?}", diff --git a/src/eventloop.rs b/src/eventloop.rs new file mode 100644 index 0000000..7f1bb4a --- /dev/null +++ b/src/eventloop.rs @@ -0,0 +1,617 @@ +use crate::requests::{PublishPayload, Request}; +use crate::state::{MqttConnectionStatus, MqttState}; +use crate::MqttOptions; +use crate::{EventError, NetworkError, Notification}; +use core::convert::{AsMut, TryInto}; +use core::ops::RangeTo; +use embedded_nal::TcpClient; +use heapless::{consts, spsc, ArrayLength, Vec}; +use mqttrs::{decode_slice, encode_slice, Connect, Packet, Protocol, QoS}; + +/// Encapsulate application-layer transaction. For example, an implementer can +/// be a raw TCP traffic, TLS or WebSocket. Typically `T` implements `TcpClient` +/// and `Self` would be a newtype wrapping `TcpClient::TcpSocket`. The +/// underlying socket must be connected to a broker in advance and ready for +/// read/write, e.g., having done SSL/TLS handshake if it is a TLS socket. It is +/// also a user's responsibility to disconnect it on abort notification. A +/// buffer is supposed to contain contiguous MQTT payloads encoded without +/// encryption. +pub trait Session { + /// An error representing transaction failures. + type Error; + /// Non-blocking read from the underlying socket. + fn try_read(&mut self, stack: &T, buffer: &mut [u8]) -> nb::Result; + /// Non-blocking write to the underlying socket. + fn try_write(&mut self, stack: &T, buffer: &[u8]) -> nb::Result; +} + +/// A newtype wrapping a socket of type `S`, given another type `T` of +/// `TcpClient`. +pub struct TcpSession(S); + +impl TcpSession { + /// Releases the underlying socket. + pub fn free(self) -> S { + self.0 + } +} + +impl From for TcpSession { + fn from(socket: S) -> Self { + Self(socket) + } +} + +impl Session for TcpSession +where + T: TcpClient, +{ + type Error = ::Error; + fn try_read(&mut self, stack: &T, buffer: &mut [u8]) -> nb::Result { + stack.receive(&mut self.0, buffer) + } + + fn try_write(&mut self, stack: &T, buffer: &[u8]) -> nb::Result { + stack.send(&mut self.0, buffer) + } +} + +pub struct EventLoop<'a, 'b, L, O, P> +where + L: ArrayLength>, + P: PublishPayload, +{ + /// Current state of the connection + pub state: MqttState

, + /// Last outgoing packet time + pub last_outgoing_timer: O, + /// Options of the current mqtt connection + pub options: MqttOptions<'b>, + /// Request stream + pub requests: spsc::Consumer<'a, Request

, L, u8>, + tx_buf: Vec, + rx_buf: PacketBuffer, +} + +impl<'a, 'b, L, O, P> EventLoop<'a, 'b, L, O, P> +where + L: ArrayLength>, + O: embedded_hal::timer::CountDown, + O::Time: From, + P: PublishPayload + Clone, +{ + pub fn new( + requests: spsc::Consumer<'a, Request

, L, u8>, + outgoing_timer: O, + options: MqttOptions<'b>, + ) -> Self { + Self { + state: MqttState::new(), + last_outgoing_timer: outgoing_timer, + options, + requests, + tx_buf: Vec::new(), + rx_buf: PacketBuffer::new(), + } + } + + fn should_handle_request(&mut self) -> bool { + let qos_space = self.state.outgoing_pub.len() < self.options.inflight(); + + let qos_0 = if let Some(Request::Publish(p)) = self.requests.peek() { + p.qos == QoS::AtMostOnce + } else { + false + }; + + if qos_0 { + true + } else { + self.requests.ready() && qos_space + } + } + + pub fn yield_event(&mut self, stack: &T, session: &mut S) -> nb::Result + where + S: Session, + { + let packet_buf = &mut [0u8; 1024]; + let o = if self.should_handle_request() { + // Handle requests + let request = unsafe { self.requests.dequeue_unchecked() }; + self.state + .handle_outgoing_request(request, packet_buf) + .map_err(EventError::from) + } else if self.last_outgoing_timer.try_wait().is_ok() { + // Handle ping + self.state + .handle_outgoing_packet(Packet::Pingreq) + .map_err(EventError::from) + } else { + self.receive(stack, session) + }; + + let (notification, outpacket) = match o { + Ok((n, p)) => (n, p), + Err(e) => { + defmt::debug!("Got an error while handling the incoming packet."); + return Ok(Notification::Abort(e)); + } + }; + + if let Some(p) = outpacket { + if let Err(e) = self.send(stack, session, p) { + defmt::debug!("Failed to send an outgoing packet."); + return Ok(Notification::Abort(e)); + } else { + self.last_outgoing_timer + .try_start(self.options.keep_alive_ms()) + .ok(); + } + } + + if let Some(n) = notification { + Ok(n) + } else { + Err(nb::Error::WouldBlock) + } + } + + pub fn send<'d, T, S>( + &mut self, + stack: &T, + session: &mut S, + pkt: Packet<'d>, + ) -> Result + where + S: Session, + { + let capacity = self.tx_buf.capacity(); + self.tx_buf.clear(); + self.tx_buf + .resize(capacity, 0x00u8) + .unwrap_or_else(|()| unreachable!("Input length equals to the current capacity.")); + let size = encode_slice(&pkt, self.tx_buf.as_mut())?; + nb::block!(session.try_write(stack, &self.tx_buf[..size])).map_err(|_| { + defmt::error!("[send] NetworkError::Write"); + EventError::Network(NetworkError::Write) + }) + } + + pub fn receive( + &mut self, + stack: &T, + session: &mut S, + ) -> Result<(Option, Option>), EventError> + where + S: Session, + { + match self.rx_buf.receive(stack, session) { + Err(nb::Error::WouldBlock) => return Ok((None, None)), + Err(_) => return Err(EventError::Network(NetworkError::Read)), + _ => {} + } + + PacketDecoder::new(&mut self.state, &mut self.rx_buf).try_into() + } + + pub fn connect(&mut self, stack: &T, session: &mut S) -> nb::Result + where + S: Session, + { + match self.state.connection_status { + MqttConnectionStatus::Connected => Ok(false), + MqttConnectionStatus::Disconnected => { + defmt::info!("MQTT connecting.."); + self.state.await_pingresp = false; + self.state.outgoing_pub.clear(); + + let (username, password) = self.options.credentials(); + + let connect = Connect { + protocol: Protocol::MQTT311, + keep_alive: (self.options.keep_alive_ms() / 1000) as u16, + client_id: self.options.client_id(), + clean_session: self.options.clean_session(), + last_will: self.options.last_will(), + username, + password, + }; + + // mqtt connection with timeout + match self.send(stack, session, connect.into()) { + Ok(_) => { + self.state + .handle_outgoing_connect() + .map_err(|e| nb::Error::Other(e.into()))?; + + self.last_outgoing_timer.try_start(50000).ok(); + } + Err(e) => { + defmt::debug!("Failed to send a connect packet."); + return Err(nb::Error::Other(e)); + } + } + + Err(nb::Error::WouldBlock) + } + MqttConnectionStatus::Handshake => { + if self.last_outgoing_timer.try_wait().is_ok() { + defmt::debug!("Handshake timed out"); + return Err(nb::Error::Other(EventError::Timeout)); + } + + self.receive(stack, session) + .map_err(|e| nb::Error::Other(e.into())) + .and_then(|(n, p)| { + if n.is_none() && p.is_none() { + return Err(nb::Error::WouldBlock); + } + Ok(n.map(|n| n == Notification::ConnAck).unwrap_or(false)) + }) + } + } + } +} + +/// A placeholder that keeps a buffer and constructs a packet incrementally. +/// Given that underlying `TcpClient` throws `WouldBlock` in a non-blocking +/// manner, its packet construction won't block either. +struct PacketBuffer { + range: RangeTo, + buffer: Vec, +} + +impl PacketBuffer { + fn new() -> Self { + let range = ..0; + let buffer = Vec::new(); + let mut buf = Self { range, buffer }; + buf.init(); + buf + } + + /// Fills the buffer with all 0s + fn init(&mut self) { + self.range.end = 0; + self.buffer.clear(); + self.buffer + .resize(self.buffer.capacity(), 0x00u8) + .unwrap_or_else(|()| unreachable!("Length equals to the current capacity.")); + } + + /// Returns a remaining fresh part of the buffer. + fn buffer(&mut self) -> &mut [u8] { + let range = self.range.end..; + self.buffer[range].as_mut() + } + + /// After decoding a packet, overwrite the used bytes by shifting the buffer + /// by its length. Assumes the length fits within the buffer's capacity. + fn rotate(&mut self, length: usize) { + self.buffer.copy_within(length.., 0); + self.range.end -= length; + self.buffer.truncate(self.buffer.capacity() - length); + self.buffer + .resize(self.buffer.capacity(), 0) + .unwrap_or_else(|()| unreachable!("Length equals to the current capacity.")); + } + + /// Receives bytes from a network socket in non-blocking mode. If incoming + /// bytes found, the range gets extended covering them. + fn receive(&mut self, stack: &T, session: &mut S) -> nb::Result<(), EventError> + where + S: Session, + { + let buffer = self.buffer(); + let len = session + .try_read(stack, buffer) + .map_err(|e| e.map(|_| EventError::Network(NetworkError::Read)))?; + self.range.end += len; + Ok(()) + } +} + +/// Provides contextual information for decoding packets. If an incoming packet +/// is well-formed and has a packet type the underlying state expects, returns a +/// notification. On an error, cleans up its buffer state. +struct PacketDecoder<'a, P> +where + P: PublishPayload + Clone, +{ + state: &'a mut MqttState

, + packet_buffer: &'a mut PacketBuffer, + is_err: Option, +} + +impl<'a, P> PacketDecoder<'a, P> +where + P: PublishPayload + Clone, +{ + fn new(state: &'a mut MqttState

, packet_buffer: &'a mut PacketBuffer) -> Self { + Self { + state, + packet_buffer, + is_err: None, + } + } + + // https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718023 + fn packet_length(&self) -> Option { + // The result of earlier decode_slice failed with an error or incomplete + // packet. + if self.is_err.unwrap_or(true) { + return None; + } + + // The buffer contains a valid packet. + self.packet_buffer + .buffer + .iter() + .skip(1) + .take(4) + .scan(true, |continuation, byte| { + let has_successor = byte & 0x80 != 0x00; + let length = (byte & 0x7f) as usize; + if *continuation { + *continuation = has_successor; + length.into() + } else { + // Short-circuit + None + } + }) + .enumerate() + .fold(1, |acc, (i, length)| { + acc + 1 + length * 0x80_usize.pow(i as u32) + }) + .into() + } +} + +impl<'a, P> TryInto<(Option, Option>)> for PacketDecoder<'a, P> +where + P: PublishPayload + Clone, +{ + type Error = EventError; + fn try_into(mut self) -> Result<(Option, Option>), Self::Error> { + let buffer = self.packet_buffer.buffer[self.packet_buffer.range].as_ref(); + match decode_slice(buffer) { + Err(_e) => { + self.is_err.replace(true); + Err(EventError::Network(NetworkError::Read)) + } + Ok(Some(packet)) => { + self.is_err.replace(false); + self.state + .handle_incoming_packet(packet) + .map_err(EventError::from) + } + Ok(None) => Ok((None, None)), + } + } +} + +impl<'a, P> Drop for PacketDecoder<'a, P> +where + P: PublishPayload + Clone, +{ + fn drop(&mut self) { + if let Some(is_err) = self.is_err { + if is_err { + self.packet_buffer.init(); + } else { + let length = self + .packet_length() + .unwrap_or_else(|| unreachable!("A valid packet has a non-zero length.")); + self.packet_buffer.rotate(length); + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::PublishRequest; + use embedded_hal::timer::CountDown; + use heapless::{consts, spsc::Queue, String, Vec}; + use mqttrs::{Connack, ConnectReturnCode, Pid, Publish, QosPid}; + + #[derive(Debug)] + struct CdMock { + time: u32, + } + + impl CountDown for CdMock { + type Error = core::convert::Infallible; + type Time = u32; + fn try_start(&mut self, count: T) -> Result<(), Self::Error> + where + T: Into, + { + self.time = count.into(); + Ok(()) + } + fn try_wait(&mut self) -> nb::Result<(), Self::Error> { + Ok(()) + } + } + + struct MockSession { + pub should_fail_read: bool, + pub should_fail_write: bool, + } + + impl Session<()> for MockSession { + type Error = (); + + fn try_write(&mut self, _stack: &(), buffer: &[u8]) -> nb::Result { + if self.should_fail_write { + Err(nb::Error::Other(())) + } else { + Ok(buffer.len()) + } + } + + fn try_read(&mut self, _stack: &(), buffer: &mut [u8]) -> nb::Result { + if self.should_fail_read { + Err(nb::Error::Other(())) + } else { + let connack = Packet::Connack(Connack { + session_present: false, + code: ConnectReturnCode::Accepted, + }); + let size = encode_slice(&connack, buffer).unwrap(); + Ok(size) + } + } + } + + #[test] + fn success_receive_multiple_packets() { + let mut state = MqttState::>::new(); + let mut rx_buf = PacketBuffer::new(); + let connack = Connack { + session_present: false, + code: ConnectReturnCode::Accepted, + }; + let publish = Publish { + dup: false, + qospid: QosPid::AtLeastOnce(Pid::new()), + retain: false, + topic_name: "test/topic", + payload: &[0xff; 1003], + }; + + let connack_len = encode_slice(&Packet::from(connack), rx_buf.buffer()).unwrap(); + rx_buf.range.end += connack_len; + let publish_len = encode_slice(&Packet::from(publish.clone()), rx_buf.buffer()).unwrap(); + rx_buf.range.end += publish_len; + assert_eq!(rx_buf.range.end, rx_buf.buffer.capacity()); + + // Decode the first Connack packet on the Handshake state. + state.connection_status = MqttConnectionStatus::Handshake; + let (n, p) = PacketDecoder::new(&mut state, &mut rx_buf) + .try_into() + .unwrap(); + assert_eq!(n, Some(Notification::ConnAck)); + assert_eq!(p, None); + + // Decode the second Publish packet on the Connected state. + assert_eq!(state.connection_status, MqttConnectionStatus::Connected); + let (n, p) = PacketDecoder::new(&mut state, &mut rx_buf) + .try_into() + .unwrap(); + let publish_notification = match n { + Some(Notification::Publish(p)) => p, + _ => panic!(), + }; + assert_eq!(&publish_notification.payload, publish.payload); + assert_eq!(p, Some(Packet::Puback(Pid::default()))); + assert_eq!(rx_buf.range.end, 0); + assert!((0..1024).all(|i| rx_buf.buffer[i] == 0)); + } + + #[test] + fn failure_receive_multiple_packets() { + let mut state = MqttState::>::new(); + let mut rx_buf = PacketBuffer::new(); + let connack_malformed = Connack { + session_present: false, + code: ConnectReturnCode::Accepted, + }; + let publish = Publish { + dup: false, + qospid: QosPid::AtLeastOnce(Pid::new()), + retain: false, + topic_name: "test/topic", + payload: &[0xff; 1003], + }; + + let connack_malformed_len = + encode_slice(&Packet::from(connack_malformed), rx_buf.buffer()).unwrap(); + rx_buf.buffer()[3] = 6; // An invalid connect return code. + rx_buf.range.end += connack_malformed_len; + let publish_len = encode_slice(&Packet::from(publish.clone()), rx_buf.buffer()).unwrap(); + rx_buf.range.end += publish_len; + assert_eq!(rx_buf.range.end, rx_buf.buffer.capacity()); + + // When a packet is malformed, we cannot tell its length. The decoder + // discards the entire buffer. + state.connection_status = MqttConnectionStatus::Handshake; + match PacketDecoder::new(&mut state, &mut rx_buf).try_into() { + Ok((_, _)) => panic!(), + Err(e) => { + assert_eq!(e, EventError::Network(NetworkError::Read)) + } + } + assert_eq!(state.connection_status, MqttConnectionStatus::Handshake); + assert_eq!(rx_buf.range.end, 0); + assert!((0..1024).all(|i| rx_buf.buffer[i] == 0)); + } + + #[test] + #[ignore] + fn retry_behaviour() { + static mut Q: Queue>, consts::U5, u8> = + Queue(heapless::i::Queue::u8()); + + let mut session = MockSession { + should_fail_read: false, + should_fail_write: false, + }; + + let (_p, c) = unsafe { Q.split() }; + let mut event = + EventLoop::<_, _, _>::new(c, CdMock { time: 0 }, MqttOptions::new("client")); + + event + .state + .outgoing_pub + .insert( + 2, + PublishRequest { + dup: false, + qos: QoS::AtLeastOnce, + retain: false, + topic_name: String::from("some/topic/name2"), + payload: Vec::new(), + }, + ) + .unwrap(); + + event + .state + .outgoing_pub + .insert( + 3, + PublishRequest { + dup: false, + qos: QoS::AtLeastOnce, + retain: false, + topic_name: String::from("some/topic/name3"), + payload: Vec::new(), + }, + ) + .unwrap(); + + event + .state + .outgoing_pub + .insert( + 4, + PublishRequest { + dup: false, + qos: QoS::AtLeastOnce, + retain: false, + topic_name: String::from("some/topic/name4"), + payload: Vec::new(), + }, + ) + .unwrap(); + + event.state.connection_status = MqttConnectionStatus::Handshake; + event.connect(&(), &mut session).unwrap(); + } +} diff --git a/src/lib.rs b/src/lib.rs index 28a35e0..d04f920 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,26 +1,23 @@ #![no_std] mod client; +mod eventloop; mod options; mod requests; mod state; -use no_std_net::SocketAddr; -use state::{MqttState, StateError}; - -use core::convert::TryFrom; - -use embedded_nal::{AddrType, Dns, Mode, TcpStack}; -use heapless::{consts, spsc, ArrayLength, String, Vec}; -use mqttrs::{decode_slice, encode_slice, Pid, Suback}; - pub use client::{Mqtt, MqttClient, MqttClientError}; +use core::convert::TryFrom; +pub use eventloop::{EventLoop, Session, TcpSession}; +use heapless::{consts, String, Vec}; +use mqttrs::Pid; pub use mqttrs::{ - Connect, Packet, Protocol, Publish, QoS, QosPid, Subscribe, SubscribeReturnCodes, + Connect, Packet, Protocol, Publish, QoS, QosPid, Suback, Subscribe, SubscribeReturnCodes, SubscribeTopic, Unsubscribe, }; -pub use options::{Broker, MqttOptions}; +pub use options::MqttOptions; pub use requests::{PublishPayload, PublishRequest, Request, SubscribeRequest, UnsubscribeRequest}; +use state::StateError; #[derive(Debug, PartialEq)] pub struct PublishNotification { @@ -35,6 +32,8 @@ pub struct PublishNotification { /// happening in the eventloop #[derive(Debug, PartialEq)] pub enum Notification { + /// Incoming connection acknowledge + ConnAck, /// Incoming publish from the broker Publish(PublishNotification), /// Incoming puback from the broker @@ -97,835 +96,3 @@ impl From for EventError { EventError::MqttState(e) } } - -pub struct MqttEvent<'a, 'b, L, S, O, P> -where - L: ArrayLength>, - P: PublishPayload, -{ - /// Current state of the connection - pub state: MqttState, - /// Options of the current mqtt connection - pub options: MqttOptions<'b>, - /// Network socket - pub socket: Option, - /// Request stream - pub requests: spsc::Consumer<'a, Request

, L, u8>, - // Outgoing QoS 1, 2 publishes which aren't acked yet - // pending_pub: FnvIndexMap, consts::U3>, - // Packet ids of released QoS 2 publishes - // pending_rel: FnvIndexSet, -} - -impl<'a, 'b, L, S, O, P> MqttEvent<'a, 'b, L, S, O, P> -where - L: ArrayLength>, - O: embedded_hal::timer::CountDown, - O::Time: From, - P: PublishPayload + Clone, -{ - pub fn new( - requests: spsc::Consumer<'a, Request

, L, u8>, - outgoing_timer: O, - options: MqttOptions<'b>, - ) -> Self { - MqttEvent { - state: MqttState::new(outgoing_timer), - options, - socket: None, - requests, - // pending_pub: IndexMap::new(), - // pending_rel: IndexSet::new(), - } - } - - pub fn connect>( - &mut self, - network: &N, - ) -> nb::Result { - // connect to the broker - self.network_connect(network)?; - if self.mqtt_connect(network)? { - // Handle state after reconnect events - // self.populate_pending(); - return Ok(true); - } - - Ok(false) - } - - fn should_handle_request(&mut self) -> bool { - let qos_space = self.state.outgoing_pub.len() < self.options.inflight(); - - let qos_0 = if let Some(Request::Publish(p)) = self.requests.peek() { - p.qos == QoS::AtMostOnce - } else { - false - }; - - if qos_0 { - true - } else { - self.requests.ready() && qos_space - } - } - - pub fn yield_event>( - &mut self, - network: &N, - ) -> nb::Result { - let packet_buf = &mut [0u8; 1024]; - - let o = if self.should_handle_request() { - // Handle requests - let request = unsafe { self.requests.dequeue_unchecked() }; - self.state.handle_outgoing_request(request, packet_buf) - } else if let Some(packet) = match self.receive(network, packet_buf) { - Ok(p) => p, - Err(EventError::Encoding(e)) => { - defmt::debug!("Encoding error!"); - return Ok(Notification::Abort(e.into())); - } - Err(e) => { - defmt::debug!("Disconnecting from receive error!"); - self.disconnect(network); - return Ok(Notification::Abort(e)); - } - } { - // Handle incoming - self.state.handle_incoming_packet(packet) - // } else if let Some(p) = self.get_pending_rel() { - // // Handle pending PubRec - // self.state.handle_outgoing_packet(Packet::Pubrec(p)) - // } else if let Some(publish) = self.get_pending_pub() { - // // Handle pending Publish - // self.state - // .handle_outgoing_request(publish.into(), packet_buf) - } else if self.state.last_outgoing_timer.try_wait().is_ok() { - // Handle ping - self.state.handle_outgoing_packet(Packet::Pingreq) - } else { - Ok((None, None)) - }; - - let (notification, outpacket) = match o { - Ok((n, p)) => (n, p), - Err(e) => { - defmt::debug!("Disconnecting from handling error!"); - self.disconnect(network); - return Ok(Notification::Abort(e.into())); - } - }; - - if let Some(p) = outpacket { - if let Err(e) = self.send(network, p) { - defmt::debug!("Disconnecting from send error!"); - self.disconnect(network); - return Ok(Notification::Abort(e)); - } else { - self.state - .last_outgoing_timer - .try_start(self.options.keep_alive_ms()) - .ok(); - } - } - - if let Some(n) = notification { - Ok(n) - } else { - Err(nb::Error::WouldBlock) - } - } - - // pub fn get_pending_rel(&mut self) -> Option { - // let p = match self.pending_rel.iter().next() { - // Some(p) => *p, - // None => return None, - // }; - // self.pending_rel.remove(&p); - // Pid::try_from(p).ok() - // } - - // pub fn get_pending_pub(&mut self) -> Option> { - // let pid = match self.pending_pub.keys().next() { - // Some(p) => *p, - // None => return None, - // }; - // self.pending_pub.remove(&pid) - // } - - // fn populate_pending(&mut self) { - // let pending_pub = core::mem::replace(&mut self.state.outgoing_pub, IndexMap::new()); - - // defmt::info!("Populating pending publish: {:?}", pending_pub.len()); - - // self.pending_pub - // .extend(pending_pub.iter().map(|(key, value)| (*key, value.clone()))); - - // let pending_rel = core::mem::replace(&mut self.state.outgoing_rel, IndexSet::new()); - - // defmt::info!("populating pending rel: {:?}", pending_rel.len()); - - // self.pending_rel.extend(pending_rel.iter()); - // } - - pub fn send<'d, N: TcpStack>( - &mut self, - network: &N, - pkt: Packet<'d>, - ) -> Result { - match self.socket { - Some(ref mut socket) => { - let mut tx_buf: [u8; 1024] = [0; 1024]; - let size = encode_slice(&pkt, &mut tx_buf)?; - nb::block!(network.write(socket, &tx_buf[..size])).map_err(|_| { - defmt::error!("[send] NetworkError::Write"); - EventError::Network(NetworkError::Write) - }) - } - _ => Err(EventError::Network(NetworkError::NoSocket)), - } - } - - pub fn receive<'d, N: TcpStack>( - &mut self, - network: &N, - packet_buf: &'d mut [u8], - ) -> Result>, EventError> { - match self.socket { - Some(ref mut socket) => { - match network.read_with(socket, |a, b| parse_header(a, b, packet_buf)) { - Ok(0) | Err(nb::Error::WouldBlock) => Ok(None), - Ok(size) => { - let p = decode_slice(&packet_buf[..size]).map_err(EventError::Encoding); - - // if let Ok(Some(Packet::Puback(pid))) = p { - // defmt::info!("Got Puback! {:?}", pid.get()); - // } - p - } - _ => Err(EventError::Network(NetworkError::Read)), - } - } - _ => Err(EventError::Network(NetworkError::NoSocket)), - } - } - - fn lookup_host>( - &mut self, - network: &N, - ) -> Result<(heapless::String, SocketAddr), EventError> { - match self.options.broker() { - (Broker::Hostname(h), p) => { - let socket_addr = SocketAddr::new( - network.gethostbyname(h, AddrType::IPv4).map_err(|_e| { - defmt::info!("Failed to resolve IP!"); - EventError::Network(NetworkError::DnsLookupFailed) - })?, - p, - ); - Ok((heapless::String::from(h), socket_addr)) - } - (Broker::IpAddr(ip), p) => { - let socket_addr = SocketAddr::new(ip, p); - let domain = network.gethostbyaddr(ip).map_err(|_e| { - defmt::info!("Failed to resolve hostname!"); - EventError::Network(NetworkError::DnsLookupFailed) - })?; - - Ok((domain, socket_addr)) - } - } - } - - fn network_connect>( - &mut self, - network: &N, - ) -> Result<(), EventError> { - if let Some(socket) = &self.socket { - match network.is_connected(socket) { - Ok(true) => return Ok(()), - Err(_e) => { - self.socket = None; - return Err(EventError::Network(NetworkError::SocketClosed)); - } - Ok(false) => {} - } - }; - - self.state.connection_status = state::MqttConnectionStatus::Disconnected; - - let socket = network - .open(Mode::Blocking) - .map_err(|_e| EventError::Network(NetworkError::SocketOpen))?; - - match self.lookup_host(network) { - Ok((_hostname, socket_addr)) => { - self.socket = Some( - network - .connect(socket, socket_addr) - .map_err(|_e| EventError::Network(NetworkError::SocketConnect))?, - ); - - // if let Some(root_ca) = self.options.ca() { - // // Add root CA - // }; - - // if let Some((certificate, private_key)) = self.options.client_auth() { - // // Enable SSL for self.socket, with broker (hostname) - // }; - - defmt::debug!("Network connected!"); - - Ok(()) - } - Err(e) => { - // Make sure to cleanup socket, in case we fail DNS lookup for some reason - network - .close(socket) - .map_err(|_e| EventError::Network(NetworkError::SocketClosed))?; - Err(e) - } - } - } - - pub fn disconnect>(&mut self, network: &N) { - self.state.connection_status = state::MqttConnectionStatus::Disconnected; - if let Some(socket) = self.socket.take() { - network.close(socket).ok(); - } - } - - fn mqtt_connect>( - &mut self, - network: &N, - ) -> nb::Result { - match self.state.connection_status { - state::MqttConnectionStatus::Connected => Ok(false), - state::MqttConnectionStatus::Disconnected => { - defmt::info!("MQTT connecting.."); - self.state.await_pingresp = false; - self.state.outgoing_pub.clear(); - - let (username, password) = self.options.credentials(); - - let connect = Connect { - protocol: Protocol::MQTT311, - keep_alive: (self.options.keep_alive_ms() / 1000) as u16, - client_id: self.options.client_id(), - clean_session: self.options.clean_session(), - last_will: self.options.last_will(), - username, - password, - }; - - // mqtt connection with timeout - match self.send(network, connect.into()) { - Ok(_) => { - self.state - .handle_outgoing_connect() - .map_err(|e| nb::Error::Other(e.into()))?; - - self.state.last_outgoing_timer.try_start(50000).ok(); - } - Err(e) => { - defmt::debug!("Disconnecting from send error!"); - self.disconnect(network); - return Err(nb::Error::Other(e)); - } - } - - Err(nb::Error::WouldBlock) - } - state::MqttConnectionStatus::Handshake => { - if self.state.last_outgoing_timer.try_wait().is_ok() { - defmt::debug!("Disconnecting from handshake timeout!"); - self.disconnect(network); - return Err(nb::Error::Other(EventError::Timeout)); - } - - let packet_buf = &mut [0u8; 4]; - match self.receive(network, packet_buf) { - Ok(Some(packet)) => { - self.state - .handle_incoming_connack(packet) - .map_err(|e| nb::Error::Other(e.into()))?; - - defmt::debug!("MQTT connected!"); - - Ok(true) - } - Ok(None) => Err(nb::Error::WouldBlock), - Err(e) => Err(nb::Error::Other(e)), - } - } - } - } -} - -fn valid_header(hd: u8) -> bool { - match hd >> 4 { - 3 => true, - 6 | 8 | 10 => hd & 0x0F == 0x02, - 1..=2 | 4..=5 | 7 | 9 | 11..=14 => hd.trailing_zeros() >= 4, - _ => false, - } -} - -fn parse_header(a: &[u8], b: Option<&[u8]>, output: &mut [u8]) -> usize { - if a.is_empty() || !valid_header(a[0]) { - return 0; - } - - let mut len: usize = 0; - for pos in 0..=3 { - match b { - Some(b) if a.len() + b.len() > pos + 1 => { - // a contains atleast partial header, a + b contains rest of packet - let byte = if a.len() > pos + 1 { - a[pos + 1] - } else { - b[pos + 1 - a.len()] - }; - len += (byte as usize & 0x7F) << (pos * 7); - if (byte & 0x80) == 0 { - // Continuation bit == 0, length is parsed - let packet_len = 2 + pos + len; - if a.len() + b.len() < packet_len { - // a+b does not contain the full payload - return 0; - } else { - if output.len() < packet_len { - defmt::error!( - "Output buffer too small! {:?} < {:?}", - output.len(), - packet_len - ); - return 0; - } - - let a_copy = core::cmp::min(a.len(), packet_len); - output[..a_copy].copy_from_slice(&a[..a_copy]); - if packet_len > a_copy { - output[a_copy..packet_len].copy_from_slice(&b[..packet_len - a_copy]); - } - return packet_len; - } - } - } - None if a.len() > pos + 1 => { - // a contains the full packet - let byte = a[pos + 1]; - len += (byte as usize & 0x7F) << (pos * 7); - if (byte & 0x80) == 0 { - // Continuation bit == 0, length is parsed - let packet_len = 2 + pos + len; - if a.len() < packet_len { - // a does not contain the full payload - return 0; - } - - if output.len() < packet_len { - defmt::error!( - "Output buffer too small! {:?} < {:?}", - output.len(), - packet_len - ); - return 0; - } - output[..packet_len].copy_from_slice(&a[..packet_len]); - return packet_len; - } - } - _ => return 0, - } - } - // Continuation byte == 1 four times, that's illegal. - 0 -} - -#[cfg(test)] -mod tests { - use super::*; - use embedded_hal::timer::CountDown; - use heapless::{consts, spsc::Queue, String, Vec}; - use mqttrs::{Connack, ConnectReturnCode}; - - #[derive(Debug)] - struct CdMock { - time: u32, - } - - impl CountDown for CdMock { - type Error = core::convert::Infallible; - type Time = u32; - fn try_start(&mut self, count: T) -> Result<(), Self::Error> - where - T: Into, - { - self.time = count.into(); - Ok(()) - } - fn try_wait(&mut self) -> nb::Result<(), Self::Error> { - Ok(()) - } - } - - struct MockNetwork { - pub should_fail_read: bool, - pub should_fail_write: bool, - } - - impl Dns for MockNetwork { - type Error = (); - - fn gethostbyname( - &self, - _hostname: &str, - _addr_type: embedded_nal::AddrType, - ) -> Result { - unimplemented!() - } - fn gethostbyaddr( - &self, - _addr: embedded_nal::IpAddr, - ) -> Result, Self::Error> { - unimplemented!() - } - } - - impl TcpStack for MockNetwork { - type TcpSocket = (); - type Error = (); - - fn open(&self, _mode: embedded_nal::Mode) -> Result { - Ok(()) - } - fn connect( - &self, - _socket: Self::TcpSocket, - _remote: embedded_nal::SocketAddr, - ) -> Result { - Ok(()) - } - fn is_connected(&self, _socket: &Self::TcpSocket) -> Result { - Ok(true) - } - fn write( - &self, - _socket: &mut Self::TcpSocket, - buffer: &[u8], - ) -> nb::Result { - if self.should_fail_write { - Err(nb::Error::Other(())) - } else { - Ok(buffer.len()) - } - } - fn read( - &self, - _socket: &mut Self::TcpSocket, - buffer: &mut [u8], - ) -> nb::Result { - if self.should_fail_read { - Err(nb::Error::Other(())) - } else { - let connack = Packet::Connack(Connack { - session_present: false, - code: ConnectReturnCode::Accepted, - }); - let size = encode_slice(&connack, buffer).unwrap(); - Ok(size) - } - } - - fn read_with(&self, socket: &mut Self::TcpSocket, f: F) -> nb::Result - where - F: FnOnce(&[u8], Option<&[u8]>) -> usize, - { - let buf = &mut [0u8; 64]; - self.read(socket, buf)?; - Ok(f(buf, None)) - } - - fn close(&self, _socket: Self::TcpSocket) -> Result<(), Self::Error> { - Ok(()) - } - } - - #[test] - fn test_parse_header_puback() { - let mut out = [0u8; 128]; - let a = &[0b01000000, 0b00000010, 0, 10]; - - let len = parse_header(a, None, &mut out); - assert_eq!(len, 4); - assert_eq!(&out[..len], &a[..]); - assert_eq!( - decode_slice(&out[..len]).unwrap(), - Some(Packet::Puback(Pid::try_from(10).unwrap())) - ); - } - - #[test] - fn test_parse_header_simple() { - let mut out = [0u8; 128]; - let a = &[ - 0b00010000, 39, 0x00, 0x04, 'M' as u8, 'Q' as u8, 'T' as u8, 'T' as u8, 0x04, - 0b11001110, // +username, +password, -will retain, will qos=1, +last_will, +clean_session - 0x00, 0x0a, // 10 sec - 0x00, 0x04, 't' as u8, 'e' as u8, 's' as u8, 't' as u8, // client_id - 0x00, 0x02, '/' as u8, 'a' as u8, // will topic = '/a' - 0x00, 0x07, 'o' as u8, 'f' as u8, 'f' as u8, 'l' as u8, 'i' as u8, 'n' as u8, - 'e' as u8, // will msg = 'offline' - 0x00, 0x04, 'r' as u8, 'u' as u8, 's' as u8, 't' as u8, // username = 'rust' - 0x00, 0x02, 'm' as u8, 'q' as u8, // password = 'mq' - ]; - - let len = parse_header(a, None, &mut out); - assert_eq!(len, 41); - assert_eq!(&out[..len], &a[..]); - assert_eq!( - decode_slice(&out[..len]).unwrap(), - Some(Packet::Connect(Connect { - protocol: Protocol::MQTT311, - keep_alive: 10, - client_id: "test".into(), - clean_session: true, - last_will: Some(mqttrs::LastWill { - topic: "/a".into(), - message: b"offline", - qos: QoS::AtLeastOnce, - retain: false, - }), - username: Some("rust".into()), - password: Some(b"mq"), - })) - ); - } - - #[test] - fn test_parse_header_additional() { - let mut out = [0u8; 128]; - let a = &[ - 0b00010000, 39, 0x00, 0x04, 'M' as u8, 'Q' as u8, 'T' as u8, 'T' as u8, 0x04, - 0b11001110, // +username, +password, -will retain, will qos=1, +last_will, +clean_session - 0x00, 0x0a, // 10 sec - 0x00, 0x04, 't' as u8, 'e' as u8, 's' as u8, 't' as u8, // client_id - 0x00, 0x02, '/' as u8, 'a' as u8, // will topic = '/a' - 0x00, 0x07, 'o' as u8, 'f' as u8, 'f' as u8, 'l' as u8, 'i' as u8, 'n' as u8, - 'e' as u8, // will msg = 'offline' - 0x00, 0x04, 'r' as u8, 'u' as u8, 's' as u8, 't' as u8, // username = 'rust' - 0x00, 0x02, 'm' as u8, 'q' as u8, // password = 'mq' - 0x00, 0x01, // additional bytes - 0x00, 0x01, // additional bytes - 0x00, 0x01, // additional bytes - ]; - - let len = parse_header(a, None, &mut out); - assert_eq!(len, 41); - assert_eq!(&out[..len], &a[..len]); - assert_eq!( - decode_slice(&out[..len]).unwrap(), - Some(Packet::Connect(Connect { - protocol: Protocol::MQTT311, - keep_alive: 10, - client_id: "test".into(), - clean_session: true, - last_will: Some(mqttrs::LastWill { - topic: "/a".into(), - message: b"offline", - qos: QoS::AtLeastOnce, - retain: false, - }), - username: Some("rust".into()), - password: Some(b"mq"), - })) - ); - } - - #[test] - fn test_parse_header_wrapped_simple() { - let mut out = [0u8; 128]; - let a = &[ - 0b00010000, 39, 0x00, 0x04, 'M' as u8, 'Q' as u8, 'T' as u8, 'T' as u8, 0x04, - ]; - - let b = &[ - 0b11001110, // +username, +password, -will retain, will qos=1, +last_will, +clean_session - 0x00, 0x0a, // 10 sec - 0x00, 0x04, 't' as u8, 'e' as u8, 's' as u8, 't' as u8, // client_id - 0x00, 0x02, '/' as u8, 'a' as u8, // will topic = '/a' - 0x00, 0x07, 'o' as u8, 'f' as u8, 'f' as u8, 'l' as u8, 'i' as u8, 'n' as u8, - 'e' as u8, // will msg = 'offline' - 0x00, 0x04, 'r' as u8, 'u' as u8, 's' as u8, 't' as u8, // username = 'rust' - 0x00, 0x02, 'm' as u8, 'q' as u8, // password = 'mq' - 0x00, 0x01, // additional bytes - 0x00, 0x01, // additional bytes - 0x00, 0x01, // additional bytes - ]; - - let expected = &[ - 0b00010000, 39, 0x00, 0x04, 'M' as u8, 'Q' as u8, 'T' as u8, 'T' as u8, 0x04, - 0b11001110, // +username, +password, -will retain, will qos=1, +last_will, +clean_session - 0x00, 0x0a, // 10 sec - 0x00, 0x04, 't' as u8, 'e' as u8, 's' as u8, 't' as u8, // client_id - 0x00, 0x02, '/' as u8, 'a' as u8, // will topic = '/a' - 0x00, 0x07, 'o' as u8, 'f' as u8, 'f' as u8, 'l' as u8, 'i' as u8, 'n' as u8, - 'e' as u8, // will msg = 'offline' - 0x00, 0x04, 'r' as u8, 'u' as u8, 's' as u8, 't' as u8, // username = 'rust' - 0x00, 0x02, 'm' as u8, 'q' as u8, // password = 'mq' - ]; - - let len = parse_header(a, Some(b), &mut out); - assert_eq!(len, 41); - assert_eq!(&out[..len], &expected[..]); - assert_eq!( - decode_slice(&out[..len]).unwrap(), - Some(Packet::Connect(Connect { - protocol: Protocol::MQTT311, - keep_alive: 10, - client_id: "test".into(), - clean_session: true, - last_will: Some(mqttrs::LastWill { - topic: "/a".into(), - message: b"offline", - qos: QoS::AtLeastOnce, - retain: false, - }), - username: Some("rust".into()), - password: Some(b"mq"), - })) - ); - } - - #[test] - fn test_parse_header_wrapped_header() { - let mut out = [0u8; 128]; - let a = &[0b00010000]; - - let b = &[ - 39, 0x00, 0x04, 'M' as u8, 'Q' as u8, 'T' as u8, 'T' as u8, 0x04, - 0b11001110, // +username, +password, -will retain, will qos=1, +last_will, +clean_session - 0x00, 0x0a, // 10 sec - 0x00, 0x04, 't' as u8, 'e' as u8, 's' as u8, 't' as u8, // client_id - 0x00, 0x02, '/' as u8, 'a' as u8, // will topic = '/a' - 0x00, 0x07, 'o' as u8, 'f' as u8, 'f' as u8, 'l' as u8, 'i' as u8, 'n' as u8, - 'e' as u8, // will msg = 'offline' - 0x00, 0x04, 'r' as u8, 'u' as u8, 's' as u8, 't' as u8, // username = 'rust' - 0x00, 0x02, 'm' as u8, 'q' as u8, // password = 'mq' - 0x00, 0x01, // additional bytes - 0x00, 0x01, // additional bytes - 0x00, 0x01, // additional bytes - ]; - - let expected = &[ - 0b00010000, 39, 0x00, 0x04, 'M' as u8, 'Q' as u8, 'T' as u8, 'T' as u8, 0x04, - 0b11001110, // +username, +password, -will retain, will qos=1, +last_will, +clean_session - 0x00, 0x0a, // 10 sec - 0x00, 0x04, 't' as u8, 'e' as u8, 's' as u8, 't' as u8, // client_id - 0x00, 0x02, '/' as u8, 'a' as u8, // will topic = '/a' - 0x00, 0x07, 'o' as u8, 'f' as u8, 'f' as u8, 'l' as u8, 'i' as u8, 'n' as u8, - 'e' as u8, // will msg = 'offline' - 0x00, 0x04, 'r' as u8, 'u' as u8, 's' as u8, 't' as u8, // username = 'rust' - 0x00, 0x02, 'm' as u8, 'q' as u8, // password = 'mq' - ]; - - let len = parse_header(a, Some(b), &mut out); - assert_eq!(len, 41); - assert_eq!(&out[..len], &expected[..]); - assert_eq!( - decode_slice(&out[..len]).unwrap(), - Some(Packet::Connect(Connect { - protocol: Protocol::MQTT311, - keep_alive: 10, - client_id: "test".into(), - clean_session: true, - last_will: Some(mqttrs::LastWill { - topic: "/a".into(), - message: b"offline", - qos: QoS::AtLeastOnce, - retain: false, - }), - username: Some("rust".into()), - password: Some(b"mq"), - })) - ); - } - - #[test] - #[ignore] - fn retry_behaviour() { - static mut Q: Queue>, consts::U5, u8> = - Queue(heapless::i::Queue::u8()); - - let network = MockNetwork { - should_fail_read: false, - should_fail_write: false, - }; - - let (_p, c) = unsafe { Q.split() }; - let mut event = MqttEvent::<_, (), _, _>::new( - c, - CdMock { time: 0 }, - MqttOptions::new("client", Broker::Hostname(""), 8883), - ); - - event - .state - .outgoing_pub - .insert( - 2, - PublishRequest { - dup: false, - qos: QoS::AtLeastOnce, - retain: false, - topic_name: String::from("some/topic/name2"), - payload: Vec::new(), - }, - ) - .unwrap(); - - event - .state - .outgoing_pub - .insert( - 3, - PublishRequest { - dup: false, - qos: QoS::AtLeastOnce, - retain: false, - topic_name: String::from("some/topic/name3"), - payload: Vec::new(), - }, - ) - .unwrap(); - - event - .state - .outgoing_pub - .insert( - 4, - PublishRequest { - dup: false, - qos: QoS::AtLeastOnce, - retain: false, - topic_name: String::from("some/topic/name4"), - payload: Vec::new(), - }, - ) - .unwrap(); - - event.state.connection_status = state::MqttConnectionStatus::Handshake; - event.socket = Some(()); - - event.connect(&network).unwrap(); - - // assert_eq!(event.pending_pub.len(), 3); - - // let mut key_iter = event.pending_pub.keys(); - // assert_eq!(key_iter.next(), Some(&2)); - // assert_eq!(key_iter.next(), Some(&3)); - // assert_eq!(key_iter.next(), Some(&4)); - } -} diff --git a/src/options.rs b/src/options.rs index a489d93..563627d 100644 --- a/src/options.rs +++ b/src/options.rs @@ -1,33 +1,4 @@ use mqttrs::LastWill; -use no_std_net::{IpAddr, Ipv4Addr}; - -#[derive(Clone, Debug, PartialEq)] -pub enum Broker<'a> { - Hostname(&'a str), - IpAddr(IpAddr), -} - -impl<'a> From<&'a str> for Broker<'a> { - fn from(s: &'a str) -> Self { - Broker::Hostname(s) - } -} - -impl<'a> From for Broker<'a> { - fn from(ip: IpAddr) -> Self { - Broker::IpAddr(ip) - } -} - -impl<'a> From for Broker<'a> { - fn from(ip: Ipv4Addr) -> Self { - Broker::IpAddr(ip.into()) - } -} - -type Certificate<'a> = &'a [u8]; -type PrivateKey<'a> = &'a [u8]; -type Password<'a> = &'a [u8]; /// Options to configure the behaviour of mqtt connection /// @@ -36,26 +7,15 @@ type Password<'a> = &'a [u8]; /// - 'b: The lifetime of the packet fields, backed by a slice buffer #[derive(Clone, Debug)] pub struct MqttOptions<'a> { - /// broker address that you want to connect to - broker_addr: Broker<'a>, - /// broker port - port: u16, /// keep alive time to send pingreq to broker when the connection is idle keep_alive_ms: u32, /// clean (or) persistent session clean_session: bool, /// client identifier client_id: &'a str, - /// certificate authority certificate - ca: Option<&'a [u8]>, - /// tls client_authentication - client_auth: Option<(Certificate<'a>, PrivateKey<'a>, Option>)>, - /// alpn settings - // alpn: Option>>, /// username and password credentials: Option<(&'a str, &'a [u8])>, /// Minimum delay time between consecutive outgoing packets - // throttle: Duration, /// maximum number of outgoing inflight messages inflight: usize, /// Last will that will be issued on unexpected disconnect @@ -64,32 +24,21 @@ pub struct MqttOptions<'a> { impl<'a> MqttOptions<'a> { /// New mqtt options - pub fn new(id: &'a str, broker: Broker<'a>, port: u16) -> MqttOptions<'a> { + pub fn new(id: &'a str) -> MqttOptions<'a> { if id.starts_with(' ') || id.is_empty() { panic!("Invalid client id") } MqttOptions { - broker_addr: broker, - port, keep_alive_ms: 60_000, clean_session: true, client_id: id, - ca: None, - client_auth: None, - // alpn: None, credentials: None, - // throttle: Duration::from_micros(0), inflight: 3, last_will: None, } } - /// Broker address - pub fn broker(&self) -> (Broker, u16) { - (self.broker_addr.clone(), self.port) - } - pub fn set_last_will(self, will: LastWill<'a>) -> Self { Self { last_will: Some(will), @@ -101,44 +50,6 @@ impl<'a> MqttOptions<'a> { self.last_will.clone() } - pub fn set_ca(self, ca: &'a [u8]) -> Self { - Self { - ca: Some(ca), - ..self - } - } - - pub fn ca(&self) -> Option<&[u8]> { - self.ca - } - - pub fn set_client_auth( - self, - cert: Certificate<'a>, - key: PrivateKey<'a>, - password: Option>, - ) -> Self { - Self { - client_auth: Some((cert, key, password)), - ..self - } - } - - pub fn client_auth(&self) -> Option<(Certificate<'a>, PrivateKey<'a>, Option>)> { - self.client_auth - } - - // pub fn set_alpn(self, alpn: Vec>) -> Self { - // Self { - // alpn: Some(alpn), - // ..self - // } - // } - - // pub fn alpn(&self) -> Option>> { - // self.alpn.clone() - // } - /// Set number of seconds after which client should ping the broker /// if there is no other data exchange pub fn set_keep_alive(self, secs: u16) -> Self { @@ -225,55 +136,30 @@ impl<'a> MqttOptions<'a> { #[cfg(test)] mod test { - use super::{Ipv4Addr, MqttOptions}; - use embedded_nal::{IpAddr, Ipv6Addr}; + use super::MqttOptions; use mqttrs::LastWill; #[test] #[should_panic] fn client_id_starts_with_space() { - let _mqtt_opts = MqttOptions::new(" client_a", Ipv4Addr::new(127, 0, 0, 1).into(), 1883) - .set_clean_session(true); + let _mqtt_opts = MqttOptions::new(" client_a").set_clean_session(true); } #[test] #[should_panic] fn no_client_id() { - let _mqtt_opts = - MqttOptions::new("", Ipv4Addr::localhost().into(), 1883).set_clean_session(true); - } - - #[test] - fn broker() { - let opts = MqttOptions::new("client_a", Ipv4Addr::localhost().into(), 1883); - assert_eq!(opts.broker_addr, Ipv4Addr::localhost().into()); - assert_eq!(opts.port, 1883); - assert_eq!(opts.broker(), (Ipv4Addr::localhost().into(), 1883)); - assert_eq!( - MqttOptions::new("client_a", "localhost".into(), 1883).broker_addr, - "localhost".into() - ); - assert_eq!( - MqttOptions::new("client_a", IpAddr::V4(Ipv4Addr::localhost()).into(), 1883) - .broker_addr, - IpAddr::V4(Ipv4Addr::localhost()).into() - ); - assert_eq!( - MqttOptions::new("client_a", IpAddr::V6(Ipv6Addr::localhost()).into(), 1883) - .broker_addr, - IpAddr::V6(Ipv6Addr::localhost()).into() - ); + let _mqtt_opts = MqttOptions::new("").set_clean_session(true); } #[test] fn client_id() { - let opts = MqttOptions::new("client_a", Ipv4Addr::localhost().into(), 1883); + let opts = MqttOptions::new("client_a"); assert_eq!(opts.client_id(), "client_a"); } #[test] fn inflight() { - let opts = MqttOptions::new("client_a", Ipv4Addr::localhost().into(), 1883); + let opts = MqttOptions::new("client_a"); assert_eq!(opts.inflight, 3); assert_eq!(opts.set_inflight(5).inflight(), 5); } @@ -281,45 +167,14 @@ mod test { #[test] #[should_panic] fn zero_inflight() { - let opts = MqttOptions::new("client_a", Ipv4Addr::localhost().into(), 1883); + let opts = MqttOptions::new("client_a"); assert_eq!(opts.inflight, 3); assert_eq!(opts.set_inflight(0).inflight(), 5); } - #[test] - fn client_auth() { - let opts = MqttOptions::new("client_a", Ipv4Addr::localhost().into(), 1883); - assert_eq!(opts.client_auth, None); - assert_eq!( - opts.clone() - .set_client_auth(b"Certificate", b"PrivateKey", None) - .client_auth(), - Some((&b"Certificate"[..], &b"PrivateKey"[..], None)) - ); - assert_eq!( - opts.set_client_auth(b"Certificate", b"PrivateKey", Some(b"Password")) - .client_auth(), - Some(( - &b"Certificate"[..], - &b"PrivateKey"[..], - Some(&b"Password"[..]) - )) - ); - } - - #[test] - fn ca() { - let opts = MqttOptions::new("client_a", Ipv4Addr::localhost().into(), 1883); - assert_eq!(opts.ca, None); - assert_eq!( - opts.set_ca(b"My Certificate Authority").ca(), - Some(&b"My Certificate Authority"[..]) - ); - } - #[test] fn keep_alive_ms() { - let opts = MqttOptions::new("client_a", Ipv4Addr::localhost().into(), 1883); + let opts = MqttOptions::new("client_a"); assert_eq!(opts.keep_alive_ms, 60_000); assert_eq!(opts.set_keep_alive(120).keep_alive_ms(), 120_000); } @@ -327,14 +182,14 @@ mod test { #[test] #[should_panic] fn keep_alive_panic() { - let opts = MqttOptions::new("client_a", Ipv4Addr::localhost().into(), 1883); + let opts = MqttOptions::new("client_a"); assert_eq!(opts.keep_alive_ms, 60_000); assert_eq!(opts.set_keep_alive(4).keep_alive_ms(), 120_000); } #[test] fn last_will() { - let opts = MqttOptions::new("client_a", Ipv4Addr::localhost().into(), 1883); + let opts = MqttOptions::new("client_a"); assert_eq!(opts.last_will, None); let will = LastWill { topic: "topic", @@ -347,14 +202,14 @@ mod test { #[test] fn clean_session() { - let opts = MqttOptions::new("client_a", Ipv4Addr::localhost().into(), 1883); + let opts = MqttOptions::new("client_a"); assert_eq!(opts.clean_session, true); assert_eq!(opts.set_clean_session(false).clean_session(), false); } #[test] fn credentials() { - let opts = MqttOptions::new("client_a", Ipv4Addr::localhost().into(), 1883); + let opts = MqttOptions::new("client_a"); assert_eq!(opts.credentials, None); assert_eq!(opts.credentials(), (None, None)); assert_eq!( diff --git a/src/state.rs b/src/state.rs index 3dba2a2..e5ca2f0 100644 --- a/src/state.rs +++ b/src/state.rs @@ -39,13 +39,11 @@ pub enum StateError { /// **Generics**: /// - O: The output timer used for keeping track of keep-alive ping-pongs. Must /// implement the [`embedded_hal::timer::CountDown`] trait -pub struct MqttState { +pub struct MqttState

{ /// Connection status pub connection_status: MqttConnectionStatus, /// Status of last ping pub await_pingresp: bool, - /// Last outgoing packet time - pub last_outgoing_timer: O, /// Packet id of the last outgoing packet pub last_pid: Pid, /// Outgoing QoS 1, 2 publishes which aren't acked yet @@ -56,20 +54,17 @@ pub struct MqttState { pub incoming_pub: FnvIndexSet, } -impl MqttState +impl

MqttState

where - O: embedded_hal::timer::CountDown, - O::Time: From, P: PublishPayload + Clone, { /// Creates new mqtt state. Same state should be used during a /// connection for persistent sessions while new state should /// instantiated for clean sessions - pub fn new(outgoing_timer: O) -> Self { + pub fn new() -> Self { MqttState { connection_status: MqttConnectionStatus::Disconnected, await_pingresp: false, - last_outgoing_timer: outgoing_timer, last_pid: Pid::new(), outgoing_pub: IndexMap::new(), @@ -117,8 +112,11 @@ where pub fn handle_incoming_packet<'a>( &mut self, packet: Packet<'a>, - ) -> Result<(Option, Option>), StateError> { + ) -> Result<(Option, Option>), StateError> { match packet { + Packet::Connack(connack) => self + .handle_incoming_connack(connack) + .map(|()| (Notification::ConnAck.into(), None)), Packet::Pingresp => self.handle_incoming_pingresp(), Packet::Publish(publish) => self.handle_incoming_publish(publish), Packet::Suback(suback) => self.handle_incoming_suback(suback), @@ -190,10 +188,10 @@ where /// This should be usually ok in case of acks due to ack ordering in normal /// conditions. But in cases where the broker doesn't guarantee the order of /// acks, the performance won't be optimal - fn handle_incoming_puback<'a>( + fn handle_incoming_puback( &mut self, pid: Pid, - ) -> Result<(Option, Option>), StateError> { + ) -> Result<(Option, Option>), StateError> { if self.outgoing_pub.contains_key(&pid.get()) { let _publish = self.outgoing_pub.remove(&pid.get()); @@ -206,19 +204,19 @@ where } } - fn handle_incoming_suback<'a>( + fn handle_incoming_suback( &mut self, suback: Suback, - ) -> Result<(Option, Option>), StateError> { + ) -> Result<(Option, Option>), StateError> { let request = None; let notification = Some(Notification::Suback(suback)); Ok((notification, request)) } - fn handle_incoming_unsuback<'a>( + fn handle_incoming_unsuback( &mut self, pid: Pid, - ) -> Result<(Option, Option>), StateError> { + ) -> Result<(Option, Option>), StateError> { let request = None; let notification = Some(Notification::Unsuback(pid)); Ok((notification, request)) @@ -228,10 +226,10 @@ where /// matching packet identifier. Removal is now a O(n) operation. This should be /// usually ok in case of acks due to ack ordering in normal conditions. But in cases /// where the broker doesn't guarantee the order of acks, the performance won't be optimal - fn handle_incoming_pubrec<'a>( + fn handle_incoming_pubrec( &mut self, pid: Pid, - ) -> Result<(Option, Option>), StateError> { + ) -> Result<(Option, Option>), StateError> { if self.outgoing_pub.contains_key(&pid.get()) { let _publish = self.outgoing_pub.remove(&pid.get()); self.outgoing_rel @@ -252,7 +250,7 @@ where fn handle_incoming_publish<'a>( &mut self, publish: Publish<'a>, - ) -> Result<(Option, Option>), StateError> { + ) -> Result<(Option, Option>), StateError> { let qospid = publish.qospid; match qospid { @@ -277,10 +275,10 @@ where } } - fn handle_incoming_pubrel<'a>( + fn handle_incoming_pubrel( &mut self, pid: Pid, - ) -> Result<(Option, Option>), StateError> { + ) -> Result<(Option, Option>), StateError> { if self.incoming_pub.contains(&pid.get()) { self.incoming_pub.remove(&pid.get()); let reply = Packet::Pubcomp(pid); @@ -291,10 +289,10 @@ where } } - fn handle_incoming_pubcomp<'a>( + fn handle_incoming_pubcomp( &mut self, pid: Pid, - ) -> Result<(Option, Option>), StateError> { + ) -> Result<(Option, Option>), StateError> { if self.outgoing_rel.contains(&pid.get()) { self.outgoing_rel.remove(&pid.get()); let notification = Some(Notification::Pubcomp(pid)); @@ -323,9 +321,9 @@ where Ok(Packet::Pingreq) } - fn handle_incoming_pingresp<'a>( + fn handle_incoming_pingresp( &mut self, - ) -> Result<(Option, Option>), StateError> { + ) -> Result<(Option, Option>), StateError> { self.await_pingresp = false; defmt::trace!("Pingresp"); Ok((None, None)) @@ -361,21 +359,12 @@ where Ok(()) } - pub fn handle_incoming_connack<'a>(&mut self, packet: Packet<'a>) -> Result<(), StateError> { - let connack = match packet { - Packet::Connack(connack) => connack, - _packet => { - defmt::error!("Invalid packet. Expecting connack!",); - - self.connection_status = MqttConnectionStatus::Disconnected; - return Err(StateError::WrongPacket); - } - }; - + pub fn handle_incoming_connack<'a>(&mut self, connack: Connack) -> Result<(), StateError> { match connack.code { ConnectReturnCode::Accepted if self.connection_status == MqttConnectionStatus::Handshake => { + defmt::debug!("MQTT connected!"); self.connection_status = MqttConnectionStatus::Connected; Ok(()) } @@ -409,30 +398,9 @@ mod test { use super::{MqttConnectionStatus, MqttState, Packet, StateError}; use crate::{Notification, PublishRequest, SubscribeRequest, UnsubscribeRequest}; use core::convert::TryFrom; - use embedded_hal::timer::CountDown; use heapless::{consts, String, Vec}; use mqttrs::*; - #[derive(Debug)] - struct CdMock { - time: u32, - } - - impl CountDown for CdMock { - type Error = core::convert::Infallible; - type Time = u32; - fn try_start(&mut self, count: T) -> Result<(), Self::Error> - where - T: Into, - { - self.time = count.into(); - Ok(()) - } - fn try_wait(&mut self) -> nb::Result<(), Self::Error> { - Ok(()) - } - } - fn build_outgoing_publish<'a>(qos: QoS) -> PublishRequest> { let topic = heapless::String::from("hello/world"); let payload = Vec::from_slice(&[1, 2, 3]).unwrap(); @@ -459,9 +427,8 @@ mod test { } } - fn build_mqttstate<'a>() -> MqttState> { - let outgoing_timer = CdMock { time: 0 }; - MqttState::new(outgoing_timer) + fn build_mqttstate<'a>() -> MqttState> { + MqttState::new() } #[test]