Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Handle messages republished by MQTT bridge following a disconnection event #3018

Merged
merged 9 commits into from
Jul 29, 2024
33 changes: 21 additions & 12 deletions crates/extensions/tedge_mqtt_bridge/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use rumqttc::Publish;
use rumqttc::SubscribeFilter;
use rumqttc::Transport;
use std::borrow::Cow;
use std::collections::hash_map;
use std::collections::HashMap;
use std::collections::VecDeque;
use std::convert::Infallible;
Expand All @@ -35,6 +36,7 @@ use tedge_actors::NullSender;
use tedge_actors::RuntimeError;
use tedge_actors::RuntimeRequest;
use tedge_actors::RuntimeRequestSink;
use tracing::debug;
use tracing::info;

pub type MqttConfig = mqtt_channel::Config;
Expand Down Expand Up @@ -329,6 +331,7 @@ async fn half_bridge(
continue;
}
};
debug!("Received notification ({name}) {notification:?}");

match notification {
Event::Incoming(Incoming::ConnAck(_)) => {
Expand Down Expand Up @@ -373,19 +376,25 @@ async fn half_bridge(
}

// Keep track of packet IDs so we can acknowledge messages
Event::Outgoing(Outgoing::Publish(pkid)) => match companion_bridge_half.recv().await {
// A message was forwarded by the other bridge half, note the packet id
Some(Some((topic, msg))) => {
loop_breaker.forward_on_topic(topic, &msg);
forward_pkid_to_received_msg.insert(pkid, msg);
Event::Outgoing(Outgoing::Publish(pkid)) => {
if let hash_map::Entry::Vacant(e) = forward_pkid_to_received_msg.entry(pkid) {
match companion_bridge_half.recv().await {
// A message was forwarded by the other bridge half, note the packet id
Some(Some((topic, msg))) => {
loop_breaker.forward_on_topic(topic, &msg);
e.insert(msg);
}

// A healthcheck message was published, ignore this packet id
Some(None) => {}

// The other bridge half has disconnected, break the loop and shut down the bridge
None => break,
}
} else {
info!("Bridge cloud connection {name} ignoring already known pkid={pkid}");
}

// A healthcheck message was published, ignore this packet id
Some(None) => {}

// The other bridge half has disconnected, break the loop and shut down the bridge
None => break,
},
}
_ => {}
}
}
Expand Down
74 changes: 71 additions & 3 deletions crates/extensions/tedge_mqtt_bridge/tests/bridge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use tokio::time::timeout;
use tracing::info;
use tracing::warn;

const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5);
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(3);

fn new_broker_and_client(name: &str, port: u16) -> (AsyncClient, EventLoop) {
let mut broker = Broker::new(get_rumqttd_config(port));
Expand Down Expand Up @@ -64,7 +64,7 @@ const HEALTH: &str = "te/device/main/#";
// TODO acknowledgement with lost connection bridge, check we acknowledge the correct message
#[tokio::test]
async fn bridge_many_messages() {
std::env::set_var("RUST_LOG", "tedge_mqtt_bridge=info");
std::env::set_var("RUST_LOG", "rumqttd=debug,tedge_mqtt_bridge=info");
let _ = env_logger::try_init();
let local_broker_port = free_port().await;
let cloud_broker_port = free_port().await;
Expand Down Expand Up @@ -116,6 +116,74 @@ async fn bridge_many_messages() {
next_received_message(&mut local).await.unwrap();
}

#[tokio::test]
async fn bridge_disconnect_while_sending() {
std::env::set_var("RUST_LOG", "tedge_mqtt_bridge=info");
let _ = env_logger::try_init();
let local_broker_port = free_port().await;
let cloud_broker_port = free_port().await;
let (local, mut ev_local) = new_broker_and_client("local", local_broker_port);
let (cloud, mut ev_cloud) = new_broker_and_client("cloud", cloud_broker_port);

// We can't easily restart rumqttd, so instead, we'll connect via a proxy
// that we can interrupt the connection of
let cloud_proxy = Proxy::start(cloud_broker_port).await;

let mut rules = BridgeConfig::new();
rules.forward_from_local("s/us", "c8y/", "").unwrap();
rules.forward_from_remote("s/ds", "c8y/", "").unwrap();

start_mqtt_bridge(local_broker_port, cloud_proxy.port, rules).await;

local.subscribe(HEALTH, QoS::AtLeastOnce).await.unwrap();

wait_until_health_status_is("up", &mut ev_local)
.await
.unwrap();

local.unsubscribe(HEALTH).await.unwrap();
local.subscribe("c8y/s/ds", QoS::AtLeastOnce).await.unwrap();
await_subscription(&mut ev_local).await;
cloud.subscribe("s/us", QoS::AtLeastOnce).await.unwrap();
await_subscription(&mut ev_cloud).await;

let poll_local = EventPoller::run_in_bg(ev_local);

// Verify messages are forwarded from cloud to local
for i in 1..10000 {
local
.publish(
"c8y/s/us",
QoS::AtMostOnce,
false,
format!("a,fake,smartrest,message{i}"),
)
.await
.unwrap();
}
cloud_proxy.interrupt_connections();
let _ev_cloud = EventPoller::run_in_bg(ev_cloud);
for _ in 1..10000 {
local
.publish(
"c8y/s/us",
QoS::AtMostOnce,
false,
"a,fake,smartrest,message",
)
.await
.unwrap();
}

let mut local = poll_local.stop_polling().await;
cloud
.publish("s/ds", QoS::AtLeastOnce, true, "test")
.await
.unwrap();

next_received_message(&mut local).await.unwrap();
}

#[tokio::test]
async fn bridge_reconnects_successfully_after_cloud_connection_interrupted() {
std::env::set_var("RUST_LOG", "tedge_mqtt_bridge=info");
Expand Down Expand Up @@ -398,7 +466,7 @@ impl Proxy {
};

write_socket.shutdown().await.unwrap();
write_conn.shutdown().await.unwrap();
let _ = write_conn.shutdown().await;
});
}
}
Expand Down