Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: fix flaky unit tests #3338

Merged
merged 17 commits into from
Jan 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 13 additions & 10 deletions crates/common/axum_tls/src/files.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ mod tests {
use assert_matches::assert_matches;
use axum::routing::get;
use axum::Router;
use camino::Utf8PathBuf;
use std::io::Cursor;

mod read_trust_store {
Expand Down Expand Up @@ -199,27 +200,28 @@ mod tests {
}

fn copy_test_file_to(test_file: &str, path: impl AsRef<Path>) -> io::Result<u64> {
std::fs::copy(format!("./test_data/{test_file}"), path)
let dir = env!("CARGO_MANIFEST_DIR");
std::fs::copy(format!("{dir}/test_data/{test_file}"), path)
}
}

#[test]
fn load_pkey_fails_when_given_x509_certificate() {
let dir = env!("CARGO_MANIFEST_DIR");
let path = Utf8PathBuf::from(format!("{dir}/test_data/ec.crt"));
assert_eq!(
load_pkey(Utf8Path::new("./test_data/ec.crt"))
.unwrap_err()
.to_string(),
"expected private key in \"./test_data/ec.crt\", found an X509 certificate"
load_pkey(&path).unwrap_err().to_string(),
format!("expected private key in {path:?}, found an X509 certificate")
);
}

#[test]
fn load_pkey_fails_when_given_certificate_revocation_list() {
let dir = env!("CARGO_MANIFEST_DIR");
let path = Utf8PathBuf::from(format!("{dir}/test_data/demo.crl"));
assert_eq!(
load_pkey(Utf8Path::new("./test_data/demo.crl"))
.unwrap_err()
.to_string(),
"expected private key in \"./test_data/demo.crl\", found a CRL"
load_pkey(&path).unwrap_err().to_string(),
format!("expected private key in {path:?}, found a CRL")
);
}

Expand Down Expand Up @@ -288,7 +290,8 @@ mod tests {
}

fn test_data(file_name: &str) -> String {
std::fs::read_to_string(format!("./test_data/{file_name}"))
let dir = env!("CARGO_MANIFEST_DIR");
std::fs::read_to_string(format!("{dir}/test_data/{file_name}"))
.with_context(|| format!("opening file {file_name} from test_data"))
.unwrap()
}
Expand Down
3 changes: 1 addition & 2 deletions crates/common/mqtt_channel/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,13 @@ log = { workspace = true }
rumqttc = { workspace = true }
serde = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true, features = ["rt", "time"] }
tokio = { workspace = true, features = ["rt", "time", "rt-multi-thread"] }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not used.

Suggested change
tokio = { workspace = true, features = ["rt", "time", "rt-multi-thread"] }
tokio = { workspace = true, features = ["rt", "time"] }

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is used when running doctests as we use #[tokio::main] there. Trying to test in isolation without it breaks, and that made checking for flakiness hard.

zeroize = { workspace = true }

[dev-dependencies]
anyhow = { workspace = true }
mqtt_tests = { workspace = true }
serde_json = { workspace = true }
serial_test = { workspace = true }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!


[lints]
workspace = true
74 changes: 53 additions & 21 deletions crates/common/mqtt_channel/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ use rumqttc::EventLoop;
use rumqttc::Incoming;
use rumqttc::Outgoing;
use rumqttc::Packet;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::OwnedSemaphorePermit;
use tokio::sync::Semaphore;
use tokio::time::sleep;

/// A connection to some MQTT server
Expand Down Expand Up @@ -88,19 +91,23 @@ impl Connection {

let (mqtt_client, event_loop) =
Connection::open(config, received_sender.clone(), error_sender.clone()).await?;
let permits = Arc::new(Semaphore::new(1));
let permit = permits.clone().acquire_owned().await.unwrap();
tokio::spawn(Connection::receiver_loop(
mqtt_client.clone(),
config.clone(),
event_loop,
received_sender,
error_sender.clone(),
pub_done_sender,
permits,
));
tokio::spawn(Connection::sender_loop(
mqtt_client,
published_receiver,
error_sender,
config.last_will_message.clone(),
pub_done_sender,
permit,
));

Ok(Connection {
Expand Down Expand Up @@ -200,9 +207,41 @@ impl Connection {
mut event_loop: EventLoop,
mut message_sender: mpsc::UnboundedSender<MqttMessage>,
mut error_sender: mpsc::UnboundedSender<MqttError>,
done: oneshot::Sender<()>,
permits: Arc<Semaphore>,
) -> Result<(), MqttError> {
let mut triggered_disconnect = false;
let mut disconnect_permit = None;

loop {
match event_loop.poll().await {
// Check if we are ready to disconnect. Due to ownership of the
// event loop, this needs to be done before we call
// `event_loop.poll()`
let remaining_events_empty = event_loop.state.inflight() == 0;
if disconnect_permit.is_some() && !triggered_disconnect && remaining_events_empty {
// `sender_loop` is not running and we have no remaining
// publishes to process
let client = mqtt_client.clone();
tokio::spawn(async move { client.disconnect().await });
triggered_disconnect = true;
}

let event = tokio::select! {
// If there is an event, we need to process that first.
// Otherwise we risk shutting down early,
// e.g. a `Publish` request from the sender is not "inflight"
// but will immediately be returned by `event_loop.poll()`
biased;

event = event_loop.poll() => event,
permit = permits.clone().acquire_owned() => {
// The `sender_loop` has now concluded
disconnect_permit = Some(permit.unwrap());
continue;
}
};

match event {
Ok(Event::Incoming(Packet::Publish(msg))) => {
if msg.payload.len() > config.max_packet_size {
error!("Dropping message received on topic {} with payload size {} that exceeds the maximum packet size of {}",
Expand Down Expand Up @@ -266,6 +305,7 @@ impl Connection {
// No more messages will be forwarded to the client
let _ = message_sender.close().await;
let _ = error_sender.close().await;
let _ = done.send(());
Ok(())
}

Expand All @@ -274,24 +314,15 @@ impl Connection {
mut messages_receiver: mpsc::UnboundedReceiver<MqttMessage>,
mut error_sender: mpsc::UnboundedSender<MqttError>,
last_will: Option<MqttMessage>,
done: oneshot::Sender<()>,
_disconnect_permit: OwnedSemaphorePermit,
) {
loop {
match messages_receiver.next().await {
None => {
// The sender channel has been closed by the client
// No more messages will be published by the client
break;
}
Some(message) => {
let payload = Vec::from(message.payload_bytes());
if let Err(err) = mqtt_client
.publish(message.topic, message.qos, message.retain, payload)
.await
{
let _ = error_sender.send(err.into()).await;
}
}
while let Some(message) = messages_receiver.next().await {
let payload = Vec::from(message.payload_bytes());
if let Err(err) = mqtt_client
.publish(message.topic, message.qos, message.retain, payload)
.await
{
let _ = error_sender.send(err.into()).await;
}
}

Expand All @@ -303,8 +334,9 @@ impl Connection {
.publish(last_will.topic, last_will.qos, last_will.retain, payload)
.await;
}
let _ = mqtt_client.disconnect().await;
let _ = done.send(());
didier-wenzek marked this conversation as resolved.
Show resolved Hide resolved

// At this point, `_disconnect_permit` is dropped
// This allows `receiver_loop` to acquire a permit and commence the shutdown process
}

pub(crate) async fn do_pause() {
Expand Down
Loading
Loading