diff --git a/roles/tests-integration/lib/sniffer.rs b/roles/tests-integration/lib/sniffer.rs index 72abc8e1e..6926ccca5 100644 --- a/roles/tests-integration/lib/sniffer.rs +++ b/roles/tests-integration/lib/sniffer.rs @@ -423,8 +423,8 @@ impl Sniffer { } } - /// used to block the test runtime - /// while we wait until Sniffer has received a message of some specific type + /// Waits until a message of the specified type is received into the `message_direction` + /// corresponding queue. pub async fn wait_for_message_type( &self, message_direction: MessageDirection, @@ -456,6 +456,56 @@ impl Sniffer { sleep(Duration::from_secs(1)).await; } } + + /// Similar to `[Sniffer::wait_for_message_type]` but also removes the messages from the queue + /// including the specified message type. + pub async fn wait_for_message_type_and_clean_queue( + &self, + message_direction: MessageDirection, + message_type: u8, + ) -> bool { + let now = std::time::Instant::now(); + loop { + let has_message_type = match message_direction { + MessageDirection::ToDownstream => self + .messages_from_upstream + .has_message_type_with_remove(message_type), + MessageDirection::ToUpstream => self + .messages_from_downstream + .has_message_type_with_remove(message_type), + }; + + // ready to unblock test runtime + if has_message_type { + return true; + } + + // 10 min timeout + // only for worst case, ideally should never be triggered + if now.elapsed().as_secs() > 10 * 60 { + panic!("Timeout waiting for message type"); + } + + // sleep to reduce async lock contention + sleep(Duration::from_secs(1)).await; + } + } + + /// Checks whether the sniffer has received a message of the specified type. + pub async fn includes_message_type( + &self, + message_direction: MessageDirection, + message_type: u8, + ) -> bool { + match message_direction { + MessageDirection::ToDownstream => { + self.messages_from_upstream.has_message_type(message_type) + } + MessageDirection::ToUpstream => { + self.messages_from_downstream.has_message_type(message_type) + } + } + } } // Utility macro to assert that the downstream and upstream roles have sent specific messages. @@ -656,6 +706,22 @@ impl MessagesAggregator { has_message } + fn has_message_type_with_remove(&self, message_type: u8) -> bool { + self.messages + .safe_lock(|messages| { + let mut cloned_messages = messages.clone(); + for (pos, (t, _)) in cloned_messages.iter().enumerate() { + if *t == message_type { + let drained = cloned_messages.drain(pos + 1..).collect(); + *messages = drained; + return true; + } + } + false + }) + .unwrap() + } + // The aggregator queues messages in FIFO order, so this function returns the oldest message in // the queue. // diff --git a/roles/tests-integration/tests/pool_integration.rs b/roles/tests-integration/tests/pool_integration.rs index e6f3446d1..94c832a0d 100644 --- a/roles/tests-integration/tests/pool_integration.rs +++ b/roles/tests-integration/tests/pool_integration.rs @@ -1,7 +1,10 @@ use integration_tests_sv2::*; use crate::sniffer::MessageDirection; -use const_sv2::{MESSAGE_TYPE_NEW_EXTENDED_MINING_JOB, MESSAGE_TYPE_NEW_TEMPLATE}; +use const_sv2::{ + MESSAGE_TYPE_MINING_SET_NEW_PREV_HASH, MESSAGE_TYPE_NEW_EXTENDED_MINING_JOB, + MESSAGE_TYPE_NEW_TEMPLATE, +}; use roles_logic_sv2::{ common_messages_sv2::{Protocol, SetupConnection}, parsers::{AnyMessage, CommonMessages, Mining, PoolMessages, TemplateDistribution}, @@ -92,24 +95,12 @@ async fn header_timestamp_value_assertion_in_new_extended_mining_job() { } _ => panic!("SetNewPrevHash not found!"), }; - // Assertions of messages between Pool and Translator Proxy (these are not necessary for the - // test itself, but they are used to pop from the sniffer's message queue) - assert_common_message!( - &pool_translator_sniffer.next_message_from_upstream(), - SetupConnectionSuccess - ); - assert_mining_message!( - &pool_translator_sniffer.next_message_from_upstream(), - OpenExtendedMiningChannelSuccess - ); - assert_mining_message!( - &pool_translator_sniffer.next_message_from_upstream(), - NewExtendedMiningJob - ); - assert_mining_message!( - &pool_translator_sniffer.next_message_from_upstream(), - SetNewPrevHash - ); + pool_translator_sniffer + .wait_for_message_type_and_clean_queue( + MessageDirection::ToDownstream, + MESSAGE_TYPE_MINING_SET_NEW_PREV_HASH, + ) + .await; // Wait for a second NewExtendedMiningJob message pool_translator_sniffer .wait_for_message_type( diff --git a/roles/tests-integration/tests/sniffer_integration.rs b/roles/tests-integration/tests/sniffer_integration.rs index 64eeaf649..84b3294b4 100644 --- a/roles/tests-integration/tests/sniffer_integration.rs +++ b/roles/tests-integration/tests/sniffer_integration.rs @@ -1,4 +1,7 @@ -use const_sv2::MESSAGE_TYPE_SETUP_CONNECTION_ERROR; +use const_sv2::{ + MESSAGE_TYPE_SETUP_CONNECTION_ERROR, MESSAGE_TYPE_SETUP_CONNECTION_SUCCESS, + MESSAGE_TYPE_SET_NEW_PREV_HASH, +}; use integration_tests_sv2::*; use roles_logic_sv2::{ common_messages_sv2::SetupConnectionError, @@ -10,7 +13,6 @@ use std::convert::TryInto; #[tokio::test] async fn test_sniffer_interrupter() { let (_tp, tp_addr) = start_template_provider(None).await; - use const_sv2::MESSAGE_TYPE_SETUP_CONNECTION_SUCCESS; let message = PoolMessages::Common(CommonMessages::SetupConnectionError(SetupConnectionError { flags: 0, @@ -33,3 +35,36 @@ async fn test_sniffer_interrupter() { assert_common_message!(&sniffer.next_message_from_downstream(), SetupConnection); assert_common_message!(&sniffer.next_message_from_upstream(), SetupConnectionError); } + +#[tokio::test] +async fn test_sniffer_wait_for_message_type_with_remove() { + let (_tp, tp_addr) = start_template_provider(None).await; + let (sniffer, sniffer_addr) = start_sniffer("".to_string(), tp_addr, false, None).await; + let _ = start_pool(Some(sniffer_addr)).await; + assert!( + sniffer + .wait_for_message_type_and_clean_queue( + MessageDirection::ToDownstream, + MESSAGE_TYPE_SET_NEW_PREV_HASH, + ) + .await + ); + assert_eq!( + sniffer + .includes_message_type( + MessageDirection::ToDownstream, + MESSAGE_TYPE_SETUP_CONNECTION_SUCCESS + ) + .await, + false + ); + assert_eq!( + sniffer + .includes_message_type( + MessageDirection::ToDownstream, + MESSAGE_TYPE_SET_NEW_PREV_HASH + ) + .await, + false + ); +}