diff --git a/crates/extensions/tedge_mqtt_bridge/src/lib.rs b/crates/extensions/tedge_mqtt_bridge/src/lib.rs index fd62a04c668..41f5059899a 100644 --- a/crates/extensions/tedge_mqtt_bridge/src/lib.rs +++ b/crates/extensions/tedge_mqtt_bridge/src/lib.rs @@ -22,6 +22,7 @@ use rumqttc::Publish; use rumqttc::SubscribeFilter; use rumqttc::Transport; use std::borrow::Cow; +use std::collections::hash_map; use std::collections::HashMap; use std::collections::VecDeque; use std::convert::Infallible; @@ -36,6 +37,7 @@ use tedge_actors::RuntimeError; use tedge_actors::RuntimeRequest; use tedge_actors::RuntimeRequestSink; use tracing::info; +use tracing::log::debug; pub type MqttConfig = mqtt_channel::Config; @@ -329,6 +331,8 @@ async fn half_bridge( continue; } }; + debug!("Received notification ({name}) {notification:?}"); + let n = format!("{:?}", notification); match notification { Event::Incoming(Incoming::ConnAck(_)) => { @@ -366,28 +370,38 @@ async fn half_bridge( Incoming::PubAck(PubAck { pkid: ack_pkid }) | Incoming::PubRec(PubRec { pkid: ack_pkid }), ) => { - if let Some(msg) = forward_pkid_to_received_msg.remove(&ack_pkid) { + if let Some(Some(msg)) = forward_pkid_to_received_msg.remove(&ack_pkid) { let target = target.clone(); tokio::spawn(async move { target.ack(&msg).await.unwrap() }); } } // Keep track of packet IDs so we can acknowledge messages - Event::Outgoing(Outgoing::Publish(pkid)) => match companion_bridge_half.recv().await { - // A message was forwarded by the other bridge half, note the packet id - Some(Some((topic, msg))) => { - loop_breaker.forward_on_topic(topic, &msg); - forward_pkid_to_received_msg.insert(pkid, msg); + Event::Outgoing(Outgoing::Publish(pkid)) => { + if let hash_map::Entry::Vacant(e) = forward_pkid_to_received_msg.entry(pkid) { + match companion_bridge_half.recv().await { + // A message was forwarded by the other bridge half, note the packet id + Some(Some((topic, msg))) => { + loop_breaker.forward_on_topic(topic, &msg); + 
e.insert(Some(msg)); + } + + // A healthcheck message was published, keep track of this packet id in case it's re-published + Some(None) => { + e.insert(None); + } + + // The other bridge half has disconnected, break the loop and shut down the bridge + None => break, + } + } else { + info!("Bridge cloud connection {name} ignoring already known pkid={pkid}"); } - - // A healthcheck message was published, ignore this packet id - Some(None) => {} - - // The other bridge half has disconnected, break the loop and shut down the bridge - None => break, - }, + } _ => {} } + + debug!("Processed notification ({name}) {n}"); } } diff --git a/crates/extensions/tedge_mqtt_bridge/tests/bridge.rs b/crates/extensions/tedge_mqtt_bridge/tests/bridge.rs index 4196375fad4..abf879e1ce8 100644 --- a/crates/extensions/tedge_mqtt_bridge/tests/bridge.rs +++ b/crates/extensions/tedge_mqtt_bridge/tests/bridge.rs @@ -28,7 +28,7 @@ use tokio::time::timeout; use tracing::info; use tracing::warn; -const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5); +const DEFAULT_TIMEOUT: Duration = Duration::from_secs(3); fn new_broker_and_client(name: &str, port: u16) -> (AsyncClient, EventLoop) { let mut broker = Broker::new(get_rumqttd_config(port)); @@ -116,6 +116,76 @@ async fn bridge_many_messages() { next_received_message(&mut local).await.unwrap(); } +#[tokio::test] +async fn bridge_disconnect_while_sending() { + std::env::set_var("RUST_LOG", "tedge_mqtt_bridge=info"); + let _ = env_logger::try_init(); + let local_broker_port = free_port().await; + let cloud_broker_port = free_port().await; + let (local, mut ev_local) = new_broker_and_client("local", local_broker_port); + let (cloud, mut ev_cloud) = new_broker_and_client("cloud", cloud_broker_port); + + // We can't easily restart rumqttd, so instead, we'll connect via a proxy + // that we can interrupt the connection of + let cloud_proxy = Proxy::start(cloud_broker_port).await; + let local_proxy = Proxy::start(local_broker_port).await; + + let mut 
rules = BridgeConfig::new(); + rules.forward_from_local("s/us", "c8y/", "").unwrap(); + rules.forward_from_remote("s/ds", "c8y/", "").unwrap(); + + start_mqtt_bridge(local_proxy.port, cloud_proxy.port, rules).await; + + local.subscribe(HEALTH, QoS::AtLeastOnce).await.unwrap(); + + wait_until_health_status_is("up", &mut ev_local) + .await + .unwrap(); + + local.unsubscribe(HEALTH).await.unwrap(); + local.subscribe("c8y/s/ds", QoS::AtLeastOnce).await.unwrap(); + await_subscription(&mut ev_local).await; + cloud.subscribe("s/us", QoS::AtLeastOnce).await.unwrap(); + await_subscription(&mut ev_cloud).await; + + let poll_local = EventPoller::run_in_bg(ev_local); + + // Publish many local-to-cloud messages so the bridge is still forwarding when the connections are interrupted + for i in 1..10000 { + local + .publish( + "c8y/s/us", + QoS::AtMostOnce, + false, + format!("a,fake,smartrest,message{i}"), + ) + .await + .unwrap(); + } + cloud_proxy.interrupt_connections(); + local_proxy.interrupt_connections(); + let _ev_cloud = EventPoller::run_in_bg(ev_cloud); + for _ in 1..10000 { + local + .publish( + "c8y/s/us", + QoS::AtMostOnce, + false, + "a,fake,smartrest,message", + ) + .await + .unwrap(); + } + + let mut local = poll_local.stop_polling().await; + cloud + .publish("s/ds", QoS::AtLeastOnce, true, "test") + .await + .unwrap(); + + next_received_message(&mut local).await.unwrap(); +} + #[tokio::test] async fn bridge_reconnects_successfully_after_cloud_connection_interrupted() { std::env::set_var("RUST_LOG", "tedge_mqtt_bridge=info"); @@ -386,7 +456,9 @@ impl Proxy { let mut stop = stop_rx.clone(); stop.mark_unchanged(); if let Ok((mut socket, _)) = listener.accept().await { - let mut conn = tokio::net::TcpStream::connect(&target).await.unwrap(); + let Ok(mut conn) = tokio::net::TcpStream::connect(&target).await else { + return; + }; tokio::spawn(async move { let (mut read_socket, mut write_socket) = socket.split(); let (mut read_conn, mut write_conn) = conn.split(); @@ -397,8 +469,8 @@ impl Proxy { _ = stop.changed() => 
info!("shutting down proxy"), }; - write_socket.shutdown().await.unwrap(); - write_conn.shutdown().await.unwrap(); + let _ = write_socket.shutdown().await; + let _ = write_conn.shutdown().await; }); } }