From 82bea4e6d9073096539946ff5c505e49562dab99 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Sat, 16 Mar 2024 21:21:09 +0530 Subject: [PATCH] refactor: `read` and `write` methods on `v4::Packet` (#821) * refactor: `Packet::read` * refactor: `Packet::write` * test: fix changes in refactor --- benchmarks/parsers/v4.rs | 3 +- rumqttc/src/framed.rs | 4 +- rumqttc/src/mqttbytes/v4/mod.rs | 88 ++++++++++++++++++++------------- rumqttc/src/state.rs | 10 ++-- rumqttc/tests/broker.rs | 2 +- 5 files changed, 65 insertions(+), 42 deletions(-) diff --git a/benchmarks/parsers/v4.rs b/benchmarks/parsers/v4.rs index 8a97bf2f2..4fbde7bab 100644 --- a/benchmarks/parsers/v4.rs +++ b/benchmarks/parsers/v4.rs @@ -1,6 +1,7 @@ use bytes::{Buf, BytesMut}; use rumqttc::mqttbytes::v4; use rumqttc::mqttbytes::QoS; +use rumqttc::Packet; use std::time::Instant; mod common; @@ -31,7 +32,7 @@ fn main() { let start = Instant::now(); let mut packets = Vec::with_capacity(count); while output.has_remaining() { - let packet = v4::read(&mut output, 10 * 1024).unwrap(); + let packet = Packet::read(&mut output, 10 * 1024).unwrap(); packets.push(packet); } diff --git a/rumqttc/src/framed.rs b/rumqttc/src/framed.rs index b0a536e78..9a5d862f2 100644 --- a/rumqttc/src/framed.rs +++ b/rumqttc/src/framed.rs @@ -58,7 +58,7 @@ impl Network { pub async fn read(&mut self) -> io::Result { loop { - let required = match read(&mut self.read, self.max_incoming_size) { + let required = match Packet::read(&mut self.read, self.max_incoming_size) { Ok(packet) => return Ok(packet), Err(mqttbytes::Error::InsufficientBytes(required)) => required, Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())), @@ -75,7 +75,7 @@ impl Network { pub async fn readb(&mut self, state: &mut MqttState) -> Result<(), StateError> { let mut count = 0; loop { - match read(&mut self.read, self.max_incoming_size) { + match Packet::read(&mut self.read, self.max_incoming_size) { Ok(packet) => { state.handle_incoming_packet(packet)?; diff --git a/rumqttc/src/mqttbytes/v4/mod.rs b/rumqttc/src/mqttbytes/v4/mod.rs index 4ac4b388c..3c9225e82 100644 --- a/rumqttc/src/mqttbytes/v4/mod.rs +++ b/rumqttc/src/mqttbytes/v4/mod.rs @@ -66,45 +66,67 @@ impl Packet { Self::Disconnect => Disconnect.size(), } } -} -/// Reads a stream of bytes and extracts next MQTT packet out of it -pub fn read(stream: &mut BytesMut, max_size: usize) -> Result { - let fixed_header = check(stream.iter(), max_size)?; + /// Reads a stream of bytes and extracts next MQTT packet out of it + pub fn read(stream: &mut BytesMut, max_size: usize) -> Result { + let fixed_header = check(stream.iter(), max_size)?; + + // Test with a stream with exactly the size to check border panics + let packet = stream.split_to(fixed_header.frame_length()); + let packet_type = fixed_header.packet_type()?; - // Test with a stream with exactly the size to check border panics - let packet = stream.split_to(fixed_header.frame_length()); - let packet_type = fixed_header.packet_type()?; + if fixed_header.remaining_len == 0 { + // no payload packets + return match packet_type { + PacketType::PingReq => Ok(Packet::PingReq), + PacketType::PingResp => Ok(Packet::PingResp), + PacketType::Disconnect => Ok(Packet::Disconnect), + _ => Err(Error::PayloadRequired), + }; + } - if fixed_header.remaining_len == 0 { - // no payload packets - return match packet_type { - PacketType::PingReq => Ok(Packet::PingReq), - PacketType::PingResp => Ok(Packet::PingResp), - PacketType::Disconnect => Ok(Packet::Disconnect), - _ => Err(Error::PayloadRequired), + let packet = packet.freeze(); + let packet = match packet_type { + PacketType::Connect => Packet::Connect(Connect::read(fixed_header, packet)?), + PacketType::ConnAck => Packet::ConnAck(ConnAck::read(fixed_header, packet)?), + PacketType::Publish => Packet::Publish(Publish::read(fixed_header, packet)?), + PacketType::PubAck => Packet::PubAck(PubAck::read(fixed_header, packet)?), + PacketType::PubRec => Packet::PubRec(PubRec::read(fixed_header, packet)?), + PacketType::PubRel => Packet::PubRel(PubRel::read(fixed_header, packet)?), + PacketType::PubComp => Packet::PubComp(PubComp::read(fixed_header, packet)?), + PacketType::Subscribe => Packet::Subscribe(Subscribe::read(fixed_header, packet)?), + PacketType::SubAck => Packet::SubAck(SubAck::read(fixed_header, packet)?), + PacketType::Unsubscribe => { + Packet::Unsubscribe(Unsubscribe::read(fixed_header, packet)?) + } + PacketType::UnsubAck => Packet::UnsubAck(UnsubAck::read(fixed_header, packet)?), + PacketType::PingReq => Packet::PingReq, + PacketType::PingResp => Packet::PingResp, + PacketType::Disconnect => Packet::Disconnect, }; - } - let packet = packet.freeze(); - let packet = match packet_type { - PacketType::Connect => Packet::Connect(Connect::read(fixed_header, packet)?), - PacketType::ConnAck => Packet::ConnAck(ConnAck::read(fixed_header, packet)?), - PacketType::Publish => Packet::Publish(Publish::read(fixed_header, packet)?), - PacketType::PubAck => Packet::PubAck(PubAck::read(fixed_header, packet)?), - PacketType::PubRec => Packet::PubRec(PubRec::read(fixed_header, packet)?), - PacketType::PubRel => Packet::PubRel(PubRel::read(fixed_header, packet)?), - PacketType::PubComp => Packet::PubComp(PubComp::read(fixed_header, packet)?), - PacketType::Subscribe => Packet::Subscribe(Subscribe::read(fixed_header, packet)?), - PacketType::SubAck => Packet::SubAck(SubAck::read(fixed_header, packet)?), - PacketType::Unsubscribe => Packet::Unsubscribe(Unsubscribe::read(fixed_header, packet)?), - PacketType::UnsubAck => Packet::UnsubAck(UnsubAck::read(fixed_header, packet)?), - PacketType::PingReq => Packet::PingReq, - PacketType::PingResp => Packet::PingResp, - PacketType::Disconnect => Packet::Disconnect, - }; + Ok(packet) + } - Ok(packet) + /// Serializes the MQTT packet into a stream of bytes + pub fn write(&self, stream: &mut BytesMut) -> Result { + match self { + Packet::Connect(c) => c.write(stream), + Packet::ConnAck(c) => c.write(stream), + Packet::Publish(p) => p.write(stream), + Packet::PubAck(p) => p.write(stream), + Packet::PubRec(p) => p.write(stream), + Packet::PubRel(p) => p.write(stream), + Packet::PubComp(p) => p.write(stream), + Packet::Subscribe(s) => s.write(stream), + Packet::SubAck(s) => s.write(stream), + Packet::Unsubscribe(u) => u.write(stream), + Packet::UnsubAck(u) => u.write(stream), + Packet::PingReq => PingReq.write(stream), + Packet::PingResp => PingResp.write(stream), + Packet::Disconnect => Disconnect.write(stream), + } + } } /// Return number of remaining length bytes required for encoding length diff --git a/rumqttc/src/state.rs b/rumqttc/src/state.rs index da33bd2f2..acee6f1da 100644 --- a/rumqttc/src/state.rs +++ b/rumqttc/src/state.rs @@ -703,7 +703,7 @@ mod test { let publish = build_incoming_publish(QoS::ExactlyOnce, 1); mqtt.handle_incoming_publish(&publish).unwrap(); - let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + let packet = Packet::read(&mut mqtt.write, 10 * 1024).unwrap(); match packet { Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), _ => panic!("Invalid network request: {:?}", packet), @@ -770,14 +770,14 @@ mod test { let publish = build_outgoing_publish(QoS::ExactlyOnce); mqtt.outgoing_publish(publish).unwrap(); - let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + let packet = Packet::read(&mut mqtt.write, 10 * 1024).unwrap(); match packet { Packet::Publish(publish) => assert_eq!(publish.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), } mqtt.handle_incoming_pubrec(&PubRec::new(1)).unwrap(); - let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + let packet = Packet::read(&mut mqtt.write, 10 * 1024).unwrap(); match packet { Packet::PubRel(pubrel) => assert_eq!(pubrel.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), @@ -790,14 +790,14 @@ mod test { let publish = build_incoming_publish(QoS::ExactlyOnce, 1); mqtt.handle_incoming_publish(&publish).unwrap(); - let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + let packet = Packet::read(&mut mqtt.write, 10 * 1024).unwrap(); match packet { Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), } mqtt.handle_incoming_pubrel(&PubRel::new(1)).unwrap(); - let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + let packet = Packet::read(&mut mqtt.write, 10 * 1024).unwrap(); match packet { Packet::PubComp(pubcomp) => assert_eq!(pubcomp.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), diff --git a/rumqttc/tests/broker.rs b/rumqttc/tests/broker.rs index a6ebacc82..ea66448f5 100644 --- a/rumqttc/tests/broker.rs +++ b/rumqttc/tests/broker.rs @@ -232,7 +232,7 @@ impl Network { pub async fn readb(&mut self, incoming: &mut VecDeque) -> io::Result<()> { let mut count = 0; loop { - match read(&mut self.read, self.max_incoming_size) { + match Packet::read(&mut self.read, self.max_incoming_size) { Ok(packet) => { incoming.push_back(packet); count += 1;