diff --git a/polkadot/node/network/availability-distribution/src/pov_requester/mod.rs b/polkadot/node/network/availability-distribution/src/pov_requester/mod.rs index 6f9ef9f6a9f8..4e23030aa499 100644 --- a/polkadot/node/network/availability-distribution/src/pov_requester/mod.rs +++ b/polkadot/node/network/availability-distribution/src/pov_requester/mod.rs @@ -139,6 +139,7 @@ mod tests { use futures::{executor, future}; use parity_scale_codec::Encode; + use sc_network::ProtocolName; use sp_core::testing::TaskExecutor; use polkadot_node_primitives::BlockData; @@ -231,7 +232,10 @@ mod tests { Some(Requests::PoVFetchingV1(outgoing)) => {outgoing} ); req.pending_response - .send(Ok(PoVFetchingResponse::PoV(pov.clone()).encode())) + .send(Ok(( + PoVFetchingResponse::PoV(pov.clone()).encode(), + ProtocolName::from(""), + ))) .unwrap(); break }, diff --git a/polkadot/node/network/availability-distribution/src/requester/fetch_task/tests.rs b/polkadot/node/network/availability-distribution/src/requester/fetch_task/tests.rs index 460f20499ed5..a5a81082e39a 100644 --- a/polkadot/node/network/availability-distribution/src/requester/fetch_task/tests.rs +++ b/polkadot/node/network/availability-distribution/src/requester/fetch_task/tests.rs @@ -25,7 +25,7 @@ use futures::{ Future, FutureExt, StreamExt, }; -use sc_network as network; +use sc_network::{self as network, ProtocolName}; use sp_keyring::Sr25519Keyring; use polkadot_node_network_protocol::request_response::{v1, Recipient}; @@ -252,7 +252,7 @@ impl TestRun { } } req.pending_response - .send(response.map(Encode::encode)) + .send(response.map(|r| (r.encode(), ProtocolName::from("")))) .expect("Sending response should succeed"); } return (valid_responses == 0) && self.valid_chunks.is_empty() diff --git a/polkadot/node/network/availability-distribution/src/tests/state.rs b/polkadot/node/network/availability-distribution/src/tests/state.rs index e95c1c3a27c2..66a8d8fcdcf9 100644 --- a/polkadot/node/network/availability-distribution/src/tests/state.rs +++ b/polkadot/node/network/availability-distribution/src/tests/state.rs @@ -19,6 +19,7 @@ use std::{ time::Duration, }; +use network::ProtocolName; use polkadot_node_subsystem_test_helpers::TestSubsystemContextHandle; use polkadot_node_subsystem_util::TimeoutExt; @@ -324,7 +325,11 @@ fn to_incoming_req( let response = rx.await; let payload = response.expect("Unexpected canceled request").result; pending_response - .send(payload.map_err(|_| network::RequestFailure::Refused)) + .send( + payload + .map_err(|_| network::RequestFailure::Refused) + .map(|r| (r, ProtocolName::from(""))), + ) .expect("Sending response is expected to work"); } .boxed(), diff --git a/polkadot/node/network/availability-recovery/src/tests.rs b/polkadot/node/network/availability-recovery/src/tests.rs index 1cb52757bac9..f1dc5b98c09b 100644 --- a/polkadot/node/network/availability-recovery/src/tests.rs +++ b/polkadot/node/network/availability-recovery/src/tests.rs @@ -22,13 +22,14 @@ use futures_timer::Delay; use parity_scale_codec::Encode; use polkadot_node_network_protocol::request_response::{ - self as req_res, IncomingRequest, Recipient, ReqProtocolNames, Requests, + self as req_res, v1::AvailableDataFetchingRequest, IncomingRequest, Protocol, Recipient, + ReqProtocolNames, Requests, }; use polkadot_node_subsystem_test_helpers::derive_erasure_chunks_with_proofs_and_root; use super::*; -use sc_network::{config::RequestResponseConfig, IfDisconnected, OutboundFailure, RequestFailure}; +use sc_network::{IfDisconnected, OutboundFailure, ProtocolName, RequestFailure}; use polkadot_node_primitives::{BlockData, PoV, Proof}; use polkadot_node_subsystem::messages::{ @@ -48,8 +49,18 @@ type VirtualOverseer = TestSubsystemContextHandle; // Deterministic genesis hash for protocol names const GENESIS_HASH: Hash = Hash::repeat_byte(0xff); -fn test_harness_fast_path>( - test: impl FnOnce(VirtualOverseer, RequestResponseConfig) -> T, +fn request_receiver( + req_protocol_names: &ReqProtocolNames, +) -> IncomingRequestReceiver { + let receiver = IncomingRequest::get_config_receiver(req_protocol_names); + // Don't close the sending end of the request protocol. Otherwise, the subsystem will terminate. + std::mem::forget(receiver.1.inbound_queue); + receiver.0 +} + +fn test_harness>( + subsystem: AvailabilityRecoverySubsystem, + test: impl FnOnce(VirtualOverseer) -> T, ) { let _ = env_logger::builder() .is_test(true) @@ -60,101 +71,23 @@ fn test_harness_fast_path>( - test: impl FnOnce(VirtualOverseer, RequestResponseConfig) -> T, -) { - let _ = env_logger::builder() - .is_test(true) - .filter(Some("polkadot_availability_recovery"), log::LevelFilter::Trace) - .try_init(); - - let pool = sp_core::testing::TaskExecutor::new(); - - let (context, virtual_overseer) = make_subsystem_context(pool.clone()); - - let (collation_req_receiver, req_cfg) = - IncomingRequest::get_config_receiver(&ReqProtocolNames::new(&GENESIS_HASH, None)); - let subsystem = AvailabilityRecoverySubsystem::with_chunks_only( - collation_req_receiver, - Metrics::new_dummy(), - ); - let subsystem = subsystem.run(context); - - let test_fut = test(virtual_overseer, req_cfg); + let test_fut = test(virtual_overseer); futures::pin_mut!(test_fut); futures::pin_mut!(subsystem); executor::block_on(future::join( async move { - let (mut overseer, _req_cfg) = test_fut.await; + let mut overseer = test_fut.await; overseer_signal(&mut overseer, OverseerSignal::Conclude).await; }, subsystem, )) .1 - .unwrap(); -} - -fn test_harness_chunks_if_pov_large< - T: Future, ->( - test: impl FnOnce(VirtualOverseer, RequestResponseConfig) -> T, -) { - let _ = env_logger::builder() - .is_test(true) - .filter(Some("polkadot_availability_recovery"), log::LevelFilter::Trace) - .try_init(); - - let pool = sp_core::testing::TaskExecutor::new(); - - let (context, virtual_overseer) = make_subsystem_context(pool.clone()); - - let (collation_req_receiver, req_cfg) = - IncomingRequest::get_config_receiver(&ReqProtocolNames::new(&GENESIS_HASH, None)); - let subsystem = AvailabilityRecoverySubsystem::with_chunks_if_pov_large( - collation_req_receiver, - Metrics::new_dummy(), - ); - let subsystem = subsystem.run(context); - - let test_fut = test(virtual_overseer, req_cfg); - - futures::pin_mut!(test_fut); - futures::pin_mut!(subsystem); - - executor::block_on(future::join( - async move { - let (mut overseer, _req_cfg) = test_fut.await; - overseer_signal(&mut overseer, OverseerSignal::Conclude).await; - }, - subsystem, - )) - .1 - .unwrap(); } const TIMEOUT: Duration = Duration::from_millis(300); @@ -342,11 +275,12 @@ impl TestState { async fn test_chunk_requests( &self, + req_protocol_names: &ReqProtocolNames, candidate_hash: CandidateHash, virtual_overseer: &mut VirtualOverseer, n: usize, who_has: impl Fn(usize) -> Has, - ) -> Vec, RequestFailure>>> { + ) -> Vec, ProtocolName), RequestFailure>>> { // arbitrary order. let mut i = 0; let mut senders = Vec::new(); @@ -380,7 +314,7 @@ impl TestState { let _ = req.pending_response.send( available_data.map(|r| - req_res::v1::ChunkFetchingResponse::from(r).encode() + (req_res::v1::ChunkFetchingResponse::from(r).encode(), req_protocol_names.get_name(Protocol::ChunkFetchingV1)) ) ); } @@ -394,10 +328,11 @@ impl TestState { async fn test_full_data_requests( &self, + req_protocol_names: &ReqProtocolNames, candidate_hash: CandidateHash, virtual_overseer: &mut VirtualOverseer, who_has: impl Fn(usize) -> Has, - ) -> Vec, RequestFailure>>> { + ) -> Vec, ProtocolName), RequestFailure>>> { let mut senders = Vec::new(); for _ in 0..self.validators.len() { // Receive a request for a chunk. @@ -433,9 +368,10 @@ impl TestState { let done = available_data.as_ref().ok().map_or(false, |x| x.is_some()); let _ = req.pending_response.send( - available_data.map(|r| - req_res::v1::AvailableDataFetchingResponse::from(r).encode() - ) + available_data.map(|r|( + req_res::v1::AvailableDataFetchingResponse::from(r).encode(), + req_protocol_names.get_name(Protocol::AvailableDataFetchingV1) + )) ); if done { break } @@ -532,8 +468,13 @@ impl Default for TestState { #[test] fn availability_is_recovered_from_chunks_if_no_group_provided() { let test_state = TestState::default(); + let req_protocol_names = ReqProtocolNames::new(&GENESIS_HASH, None); + let subsystem = AvailabilityRecoverySubsystem::with_fast_path( + request_receiver(&req_protocol_names), + Metrics::new_dummy(), + ); - test_harness_fast_path(|mut virtual_overseer, req_cfg| async move { + test_harness(subsystem, |mut virtual_overseer| async move { overseer_signal( &mut virtual_overseer, OverseerSignal::ActiveLeaves(ActiveLeavesUpdate::start_work(new_leaf( @@ -565,6 +506,7 @@ fn availability_is_recovered_from_chunks_if_no_group_provided() { test_state .test_chunk_requests( + &req_protocol_names, candidate_hash, &mut virtual_overseer, test_state.threshold(), @@ -600,6 +542,7 @@ fn availability_is_recovered_from_chunks_if_no_group_provided() { test_state .test_chunk_requests( + &req_protocol_names, new_candidate.hash(), &mut virtual_overseer, test_state.impossibility_threshold(), @@ -609,15 +552,20 @@ fn availability_is_recovered_from_chunks_if_no_group_provided() { // A request times out with `Unavailable` error. assert_eq!(rx.await.unwrap().unwrap_err(), RecoveryError::Unavailable); - (virtual_overseer, req_cfg) + virtual_overseer }); } #[test] fn availability_is_recovered_from_chunks_even_if_backing_group_supplied_if_chunks_only() { let test_state = TestState::default(); + let req_protocol_names = ReqProtocolNames::new(&GENESIS_HASH, None); + let subsystem = AvailabilityRecoverySubsystem::with_chunks_only( + request_receiver(&req_protocol_names), + Metrics::new_dummy(), + ); - test_harness_chunks_only(|mut virtual_overseer, req_cfg| async move { + test_harness(subsystem, |mut virtual_overseer| async move { overseer_signal( &mut virtual_overseer, OverseerSignal::ActiveLeaves(ActiveLeavesUpdate::start_work(new_leaf( @@ -649,6 +597,7 @@ fn availability_is_recovered_from_chunks_even_if_backing_group_supplied_if_chunk test_state .test_chunk_requests( + &req_protocol_names, candidate_hash, &mut virtual_overseer, test_state.threshold(), @@ -684,6 +633,7 @@ fn availability_is_recovered_from_chunks_even_if_backing_group_supplied_if_chunk test_state .test_chunk_requests( + &req_protocol_names, new_candidate.hash(), &mut virtual_overseer, test_state.impossibility_threshold(), @@ -693,15 +643,20 @@ fn availability_is_recovered_from_chunks_even_if_backing_group_supplied_if_chunk // A request times out with `Unavailable` error. assert_eq!(rx.await.unwrap().unwrap_err(), RecoveryError::Unavailable); - (virtual_overseer, req_cfg) + virtual_overseer }); } #[test] fn bad_merkle_path_leads_to_recovery_error() { let mut test_state = TestState::default(); + let req_protocol_names = ReqProtocolNames::new(&GENESIS_HASH, None); + let subsystem = AvailabilityRecoverySubsystem::with_fast_path( + request_receiver(&req_protocol_names), + Metrics::new_dummy(), + ); - test_harness_fast_path(|mut virtual_overseer, req_cfg| async move { + test_harness(subsystem, |mut virtual_overseer| async move { overseer_signal( &mut virtual_overseer, OverseerSignal::ActiveLeaves(ActiveLeavesUpdate::start_work(new_leaf( @@ -740,6 +695,7 @@ fn bad_merkle_path_leads_to_recovery_error() { test_state .test_chunk_requests( + &req_protocol_names, candidate_hash, &mut virtual_overseer, test_state.impossibility_threshold(), @@ -749,15 +705,20 @@ fn bad_merkle_path_leads_to_recovery_error() { // A request times out with `Unavailable` error. assert_eq!(rx.await.unwrap().unwrap_err(), RecoveryError::Unavailable); - (virtual_overseer, req_cfg) + virtual_overseer }); } #[test] fn wrong_chunk_index_leads_to_recovery_error() { let mut test_state = TestState::default(); + let req_protocol_names = ReqProtocolNames::new(&GENESIS_HASH, None); + let subsystem = AvailabilityRecoverySubsystem::with_fast_path( + request_receiver(&req_protocol_names), + Metrics::new_dummy(), + ); - test_harness_fast_path(|mut virtual_overseer, req_cfg| async move { + test_harness(subsystem, |mut virtual_overseer| async move { overseer_signal( &mut virtual_overseer, OverseerSignal::ActiveLeaves(ActiveLeavesUpdate::start_work(new_leaf( @@ -796,6 +757,7 @@ fn wrong_chunk_index_leads_to_recovery_error() { test_state .test_chunk_requests( + &req_protocol_names, candidate_hash, &mut virtual_overseer, test_state.impossibility_threshold(), @@ -805,15 +767,20 @@ fn wrong_chunk_index_leads_to_recovery_error() { // A request times out with `Unavailable` error as there are no good peers. assert_eq!(rx.await.unwrap().unwrap_err(), RecoveryError::Unavailable); - (virtual_overseer, req_cfg) + virtual_overseer }); } #[test] fn invalid_erasure_coding_leads_to_invalid_error() { let mut test_state = TestState::default(); + let req_protocol_names = ReqProtocolNames::new(&GENESIS_HASH, None); + let subsystem = AvailabilityRecoverySubsystem::with_fast_path( + request_receiver(&req_protocol_names), + Metrics::new_dummy(), + ); - test_harness_fast_path(|mut virtual_overseer, req_cfg| async move { + test_harness(subsystem, |mut virtual_overseer| async move { let pov = PoV { block_data: BlockData(vec![69; 64]) }; let (bad_chunks, bad_erasure_root) = derive_erasure_chunks_with_proofs_and_root( @@ -859,6 +826,7 @@ fn invalid_erasure_coding_leads_to_invalid_error() { test_state .test_chunk_requests( + &req_protocol_names, candidate_hash, &mut virtual_overseer, test_state.threshold(), @@ -868,15 +836,20 @@ fn invalid_erasure_coding_leads_to_invalid_error() { // f+1 'valid' chunks can't produce correct data. assert_eq!(rx.await.unwrap().unwrap_err(), RecoveryError::Invalid); - (virtual_overseer, req_cfg) + virtual_overseer }); } #[test] fn fast_path_backing_group_recovers() { let test_state = TestState::default(); + let req_protocol_names = ReqProtocolNames::new(&GENESIS_HASH, None); + let subsystem = AvailabilityRecoverySubsystem::with_fast_path( + request_receiver(&req_protocol_names), + Metrics::new_dummy(), + ); - test_harness_fast_path(|mut virtual_overseer, req_cfg| async move { + test_harness(subsystem, |mut virtual_overseer| async move { overseer_signal( &mut virtual_overseer, OverseerSignal::ActiveLeaves(ActiveLeavesUpdate::start_work(new_leaf( @@ -911,20 +884,30 @@ fn fast_path_backing_group_recovers() { test_state.respond_to_available_data_query(&mut virtual_overseer, false).await; test_state - .test_full_data_requests(candidate_hash, &mut virtual_overseer, who_has) + .test_full_data_requests( + &req_protocol_names, + candidate_hash, + &mut virtual_overseer, + who_has, + ) .await; // Recovered data should match the original one. assert_eq!(rx.await.unwrap().unwrap(), test_state.available_data); - (virtual_overseer, req_cfg) + virtual_overseer }); } #[test] fn recovers_from_only_chunks_if_pov_large() { let test_state = TestState::default(); + let req_protocol_names = ReqProtocolNames::new(&GENESIS_HASH, None); + let subsystem = AvailabilityRecoverySubsystem::with_chunks_if_pov_large( + request_receiver(&req_protocol_names), + Metrics::new_dummy(), + ); - test_harness_chunks_if_pov_large(|mut virtual_overseer, req_cfg| async move { + test_harness(subsystem, |mut virtual_overseer| async move { overseer_signal( &mut virtual_overseer, OverseerSignal::ActiveLeaves(ActiveLeavesUpdate::start_work(new_leaf( @@ -965,6 +948,7 @@ fn recovers_from_only_chunks_if_pov_large() { test_state .test_chunk_requests( + &req_protocol_names, candidate_hash, &mut virtual_overseer, test_state.threshold(), @@ -1009,6 +993,7 @@ fn recovers_from_only_chunks_if_pov_large() { test_state .test_chunk_requests( + &req_protocol_names, new_candidate.hash(), &mut virtual_overseer, test_state.impossibility_threshold(), @@ -1018,15 +1003,20 @@ fn recovers_from_only_chunks_if_pov_large() { // A request times out with `Unavailable` error. assert_eq!(rx.await.unwrap().unwrap_err(), RecoveryError::Unavailable); - (virtual_overseer, req_cfg) + virtual_overseer }); } #[test] fn fast_path_backing_group_recovers_if_pov_small() { let test_state = TestState::default(); + let req_protocol_names = ReqProtocolNames::new(&GENESIS_HASH, None); + let subsystem = AvailabilityRecoverySubsystem::with_chunks_if_pov_large( + request_receiver(&req_protocol_names), + Metrics::new_dummy(), + ); - test_harness_chunks_if_pov_large(|mut virtual_overseer, req_cfg| async move { + test_harness(subsystem, |mut virtual_overseer| async move { overseer_signal( &mut virtual_overseer, OverseerSignal::ActiveLeaves(ActiveLeavesUpdate::start_work(new_leaf( @@ -1070,20 +1060,30 @@ fn fast_path_backing_group_recovers_if_pov_small() { test_state.respond_to_available_data_query(&mut virtual_overseer, false).await; test_state - .test_full_data_requests(candidate_hash, &mut virtual_overseer, who_has) + .test_full_data_requests( + &req_protocol_names, + candidate_hash, + &mut virtual_overseer, + who_has, + ) .await; // Recovered data should match the original one. assert_eq!(rx.await.unwrap().unwrap(), test_state.available_data); - (virtual_overseer, req_cfg) + virtual_overseer }); } #[test] fn no_answers_in_fast_path_causes_chunk_requests() { let test_state = TestState::default(); + let req_protocol_names = ReqProtocolNames::new(&GENESIS_HASH, None); + let subsystem = AvailabilityRecoverySubsystem::with_fast_path( + request_receiver(&req_protocol_names), + Metrics::new_dummy(), + ); - test_harness_fast_path(|mut virtual_overseer, req_cfg| async move { + test_harness(subsystem, |mut virtual_overseer| async move { overseer_signal( &mut virtual_overseer, OverseerSignal::ActiveLeaves(ActiveLeavesUpdate::start_work(new_leaf( @@ -1119,13 +1119,19 @@ fn no_answers_in_fast_path_causes_chunk_requests() { test_state.respond_to_available_data_query(&mut virtual_overseer, false).await; test_state - .test_full_data_requests(candidate_hash, &mut virtual_overseer, who_has) + .test_full_data_requests( + &req_protocol_names, + candidate_hash, + &mut virtual_overseer, + who_has, + ) .await; test_state.respond_to_query_all_request(&mut virtual_overseer, |_| false).await; test_state .test_chunk_requests( + &req_protocol_names, candidate_hash, &mut virtual_overseer, test_state.threshold(), @@ -1135,15 +1141,20 @@ fn no_answers_in_fast_path_causes_chunk_requests() { // Recovered data should match the original one. assert_eq!(rx.await.unwrap().unwrap(), test_state.available_data); - (virtual_overseer, req_cfg) + virtual_overseer }); } #[test] fn task_canceled_when_receivers_dropped() { let test_state = TestState::default(); + let req_protocol_names = ReqProtocolNames::new(&GENESIS_HASH, None); + let subsystem = AvailabilityRecoverySubsystem::with_chunks_only( + request_receiver(&req_protocol_names), + Metrics::new_dummy(), + ); - test_harness_chunks_only(|mut virtual_overseer, req_cfg| async move { + test_harness(subsystem, |mut virtual_overseer| async move { overseer_signal( &mut virtual_overseer, OverseerSignal::ActiveLeaves(ActiveLeavesUpdate::start_work(new_leaf( @@ -1170,7 +1181,7 @@ fn task_canceled_when_receivers_dropped() { for _ in 0..test_state.validators.len() { match virtual_overseer.recv().timeout(TIMEOUT).await { - None => return (virtual_overseer, req_cfg), + None => return virtual_overseer, Some(_) => continue, } } @@ -1182,8 +1193,13 @@ fn task_canceled_when_receivers_dropped() { #[test] fn chunks_retry_until_all_nodes_respond() { let test_state = TestState::default(); + let req_protocol_names = ReqProtocolNames::new(&GENESIS_HASH, None); + let subsystem = AvailabilityRecoverySubsystem::with_chunks_only( + request_receiver(&req_protocol_names), + Metrics::new_dummy(), + ); - test_harness_chunks_only(|mut virtual_overseer, req_cfg| async move { + test_harness(subsystem, |mut virtual_overseer| async move { overseer_signal( &mut virtual_overseer, OverseerSignal::ActiveLeaves(ActiveLeavesUpdate::start_work(new_leaf( @@ -1215,6 +1231,7 @@ fn chunks_retry_until_all_nodes_respond() { test_state .test_chunk_requests( + &req_protocol_names, candidate_hash, &mut virtual_overseer, test_state.validators.len() - test_state.threshold(), @@ -1225,6 +1242,7 @@ fn chunks_retry_until_all_nodes_respond() { // we get to go another round! test_state .test_chunk_requests( + &req_protocol_names, candidate_hash, &mut virtual_overseer, test_state.impossibility_threshold(), @@ -1234,15 +1252,20 @@ fn chunks_retry_until_all_nodes_respond() { // Recovered data should match the original one. assert_eq!(rx.await.unwrap().unwrap_err(), RecoveryError::Unavailable); - (virtual_overseer, req_cfg) + virtual_overseer }); } #[test] fn not_returning_requests_wont_stall_retrieval() { let test_state = TestState::default(); + let req_protocol_names = ReqProtocolNames::new(&GENESIS_HASH, None); + let subsystem = AvailabilityRecoverySubsystem::with_chunks_only( + request_receiver(&req_protocol_names), + Metrics::new_dummy(), + ); - test_harness_chunks_only(|mut virtual_overseer, req_cfg| async move { + test_harness(subsystem, |mut virtual_overseer| async move { overseer_signal( &mut virtual_overseer, OverseerSignal::ActiveLeaves(ActiveLeavesUpdate::start_work(new_leaf( @@ -1277,13 +1300,18 @@ fn not_returning_requests_wont_stall_retrieval() { // Not returning senders won't cause the retrieval to stall: let _senders = test_state - .test_chunk_requests(candidate_hash, &mut virtual_overseer, not_returning_count, |_| { - Has::DoesNotReturn - }) + .test_chunk_requests( + &req_protocol_names, + candidate_hash, + &mut virtual_overseer, + not_returning_count, + |_| Has::DoesNotReturn, + ) .await; test_state .test_chunk_requests( + &req_protocol_names, candidate_hash, &mut virtual_overseer, // Should start over: @@ -1295,6 +1323,7 @@ fn not_returning_requests_wont_stall_retrieval() { // we get to go another round! test_state .test_chunk_requests( + &req_protocol_names, candidate_hash, &mut virtual_overseer, test_state.threshold(), @@ -1304,15 +1333,20 @@ fn not_returning_requests_wont_stall_retrieval() { // Recovered data should match the original one: assert_eq!(rx.await.unwrap().unwrap(), test_state.available_data); - (virtual_overseer, req_cfg) + virtual_overseer }); } #[test] fn all_not_returning_requests_still_recovers_on_return() { let test_state = TestState::default(); + let req_protocol_names = ReqProtocolNames::new(&GENESIS_HASH, None); + let subsystem = AvailabilityRecoverySubsystem::with_chunks_only( + request_receiver(&req_protocol_names), + Metrics::new_dummy(), + ); - test_harness_chunks_only(|mut virtual_overseer, req_cfg| async move { + test_harness(subsystem, |mut virtual_overseer| async move { overseer_signal( &mut virtual_overseer, OverseerSignal::ActiveLeaves(ActiveLeavesUpdate::start_work(new_leaf( @@ -1344,6 +1378,7 @@ fn all_not_returning_requests_still_recovers_on_return() { let senders = test_state .test_chunk_requests( + &req_protocol_names, candidate_hash, &mut virtual_overseer, test_state.validators.len(), @@ -1358,6 +1393,7 @@ fn all_not_returning_requests_still_recovers_on_return() { std::mem::drop(senders); }, test_state.test_chunk_requests( + &req_protocol_names, candidate_hash, &mut virtual_overseer, // Should start over: @@ -1370,6 +1406,7 @@ fn all_not_returning_requests_still_recovers_on_return() { // we get to go another round! test_state .test_chunk_requests( + &req_protocol_names, candidate_hash, &mut virtual_overseer, test_state.threshold(), @@ -1379,15 +1416,20 @@ fn all_not_returning_requests_still_recovers_on_return() { // Recovered data should match the original one: assert_eq!(rx.await.unwrap().unwrap(), test_state.available_data); - (virtual_overseer, req_cfg) + virtual_overseer }); } #[test] fn returns_early_if_we_have_the_data() { let test_state = TestState::default(); + let req_protocol_names = ReqProtocolNames::new(&GENESIS_HASH, None); + let subsystem = AvailabilityRecoverySubsystem::with_chunks_only( + request_receiver(&req_protocol_names), + Metrics::new_dummy(), + ); - test_harness_chunks_only(|mut virtual_overseer, req_cfg| async move { + test_harness(subsystem, |mut virtual_overseer| async move { overseer_signal( &mut virtual_overseer, OverseerSignal::ActiveLeaves(ActiveLeavesUpdate::start_work(new_leaf( @@ -1414,15 +1456,20 @@ fn returns_early_if_we_have_the_data() { test_state.respond_to_available_data_query(&mut virtual_overseer, true).await; assert_eq!(rx.await.unwrap().unwrap(), test_state.available_data); - (virtual_overseer, req_cfg) + virtual_overseer }); } #[test] fn does_not_query_local_validator() { let test_state = TestState::default(); + let req_protocol_names = ReqProtocolNames::new(&GENESIS_HASH, None); + let subsystem = AvailabilityRecoverySubsystem::with_chunks_only( + request_receiver(&req_protocol_names), + Metrics::new_dummy(), + ); - test_harness_chunks_only(|mut virtual_overseer, req_cfg| async move { + test_harness(subsystem, |mut virtual_overseer| async move { overseer_signal( &mut virtual_overseer, OverseerSignal::ActiveLeaves(ActiveLeavesUpdate::start_work(new_leaf( @@ -1453,6 +1500,7 @@ fn does_not_query_local_validator() { test_state .test_chunk_requests( + &req_protocol_names, candidate_hash, &mut virtual_overseer, test_state.validators.len(), @@ -1463,6 +1511,7 @@ fn does_not_query_local_validator() { // second round, make sure it uses the local chunk. test_state .test_chunk_requests( + &req_protocol_names, candidate_hash, &mut virtual_overseer, test_state.threshold() - 1, @@ -1471,15 +1520,20 @@ fn does_not_query_local_validator() { .await; assert_eq!(rx.await.unwrap().unwrap(), test_state.available_data); - (virtual_overseer, req_cfg) + virtual_overseer }); } #[test] fn invalid_local_chunk_is_ignored() { let test_state = TestState::default(); + let req_protocol_names = ReqProtocolNames::new(&GENESIS_HASH, None); + let subsystem = AvailabilityRecoverySubsystem::with_chunks_only( + request_receiver(&req_protocol_names), + Metrics::new_dummy(), + ); - test_harness_chunks_only(|mut virtual_overseer, req_cfg| async move { + test_harness(subsystem, |mut virtual_overseer| async move { overseer_signal( &mut virtual_overseer, OverseerSignal::ActiveLeaves(ActiveLeavesUpdate::start_work(new_leaf( @@ -1512,6 +1566,7 @@ fn invalid_local_chunk_is_ignored() { test_state .test_chunk_requests( + &req_protocol_names, candidate_hash, &mut virtual_overseer, test_state.threshold() - 1, @@ -1520,6 +1575,6 @@ fn invalid_local_chunk_is_ignored() { .await; assert_eq!(rx.await.unwrap().unwrap(), test_state.available_data); - (virtual_overseer, req_cfg) + virtual_overseer }); } diff --git a/polkadot/node/network/bridge/src/network.rs b/polkadot/node/network/bridge/src/network.rs index 2fcf5cec489d..21bed019256a 100644 --- a/polkadot/node/network/bridge/src/network.rs +++ b/polkadot/node/network/bridge/src/network.rs @@ -264,7 +264,8 @@ impl Network for Arc> { req_protocol_names: &ReqProtocolNames, if_disconnected: IfDisconnected, ) { - let (protocol, OutgoingRequest { peer, payload, pending_response }) = req.encode_request(); + let (protocol, OutgoingRequest { peer, payload, pending_response, fallback_request }) = + req.encode_request(); let peer_id = match peer { Recipient::Peer(peer_id) => Some(peer_id), @@ -315,6 +316,7 @@ impl Network for Arc> { target: LOG_TARGET, %peer_id, protocol = %req_protocol_names.get_name(protocol), + fallback_protocol = ?fallback_request.as_ref().map(|(_, p)| req_protocol_names.get_name(*p)), ?if_disconnected, "Starting request", ); @@ -324,6 +326,7 @@ impl Network for Arc> { peer_id, req_protocol_names.get_name(protocol), payload, + fallback_request.map(|(r, p)| (r, req_protocol_names.get_name(p))), pending_response, if_disconnected, ); diff --git a/polkadot/node/network/collator-protocol/src/validator_side/tests/mod.rs b/polkadot/node/network/collator-protocol/src/validator_side/tests/mod.rs index 3a9740149948..1ba6389212cc 100644 --- a/polkadot/node/network/collator-protocol/src/validator_side/tests/mod.rs +++ b/polkadot/node/network/collator-protocol/src/validator_side/tests/mod.rs @@ -17,6 +17,7 @@ use super::*; use assert_matches::assert_matches; use futures::{executor, future, Future}; +use sc_network::ProtocolName; use sp_core::{crypto::Pair, Encode}; use sp_keyring::Sr25519Keyring; use sp_keystore::Keystore; @@ -559,11 +560,11 @@ fn act_on_advertisement_v2() { .await; response_channel - .send(Ok(request_v1::CollationFetchingResponse::Collation( - candidate_a.clone(), - pov.clone(), - ) - .encode())) + .send(Ok(( + request_v1::CollationFetchingResponse::Collation(candidate_a.clone(), pov.clone()) + .encode(), + ProtocolName::from(""), + ))) .expect("Sending response should succeed"); assert_candidate_backing_second( @@ -761,11 +762,11 @@ fn fetch_one_collation_at_a_time() { candidate_a.descriptor.relay_parent = test_state.relay_parent; candidate_a.descriptor.persisted_validation_data_hash = dummy_pvd().hash(); response_channel - .send(Ok(request_v1::CollationFetchingResponse::Collation( - candidate_a.clone(), - pov.clone(), - ) - .encode())) + .send(Ok(( + request_v1::CollationFetchingResponse::Collation(candidate_a.clone(), pov.clone()) + .encode(), + ProtocolName::from(""), + ))) .expect("Sending response should succeed"); assert_candidate_backing_second( @@ -885,19 +886,19 @@ fn fetches_next_collation() { // First request finishes now: response_channel_non_exclusive - .send(Ok(request_v1::CollationFetchingResponse::Collation( - candidate_a.clone(), - pov.clone(), - ) - .encode())) + .send(Ok(( + request_v1::CollationFetchingResponse::Collation(candidate_a.clone(), pov.clone()) + .encode(), + ProtocolName::from(""), + ))) .expect("Sending response should succeed"); response_channel - .send(Ok(request_v1::CollationFetchingResponse::Collation( - candidate_a.clone(), - pov.clone(), - ) - .encode())) + .send(Ok(( + request_v1::CollationFetchingResponse::Collation(candidate_a.clone(), pov.clone()) + .encode(), + ProtocolName::from(""), + ))) .expect("Sending response should succeed"); assert_candidate_backing_second( @@ -1023,11 +1024,11 @@ fn fetch_next_collation_on_invalid_collation() { candidate_a.descriptor.relay_parent = test_state.relay_parent; candidate_a.descriptor.persisted_validation_data_hash = dummy_pvd().hash(); response_channel - .send(Ok(request_v1::CollationFetchingResponse::Collation( - candidate_a.clone(), - pov.clone(), - ) - .encode())) + .send(Ok(( + request_v1::CollationFetchingResponse::Collation(candidate_a.clone(), pov.clone()) + .encode(), + ProtocolName::from(""), + ))) .expect("Sending response should succeed"); let receipt = assert_candidate_backing_second( diff --git a/polkadot/node/network/collator-protocol/src/validator_side/tests/prospective_parachains.rs b/polkadot/node/network/collator-protocol/src/validator_side/tests/prospective_parachains.rs index c5236ef3eb21..23963e65554e 100644 --- a/polkadot/node/network/collator-protocol/src/validator_side/tests/prospective_parachains.rs +++ b/polkadot/node/network/collator-protocol/src/validator_side/tests/prospective_parachains.rs @@ -314,11 +314,11 @@ fn v1_advertisement_accepted_and_seconded() { let pov = PoV { block_data: BlockData(vec![1]) }; response_channel - .send(Ok(request_v2::CollationFetchingResponse::Collation( - candidate.clone(), - pov.clone(), - ) - .encode())) + .send(Ok(( + request_v2::CollationFetchingResponse::Collation(candidate.clone(), pov.clone()) + .encode(), + ProtocolName::from(""), + ))) .expect("Sending response should succeed"); assert_candidate_backing_second( @@ -565,11 +565,14 @@ fn second_multiple_candidates_per_relay_parent() { let pov = PoV { block_data: BlockData(vec![1]) }; response_channel - .send(Ok(request_v2::CollationFetchingResponse::Collation( - candidate.clone(), - pov.clone(), - ) - .encode())) + .send(Ok(( + request_v2::CollationFetchingResponse::Collation( + candidate.clone(), + pov.clone(), + ) + .encode(), + ProtocolName::from(""), + ))) .expect("Sending response should succeed"); assert_candidate_backing_second( @@ -717,11 +720,11 @@ fn fetched_collation_sanity_check() { let pov = PoV { block_data: BlockData(vec![1]) }; response_channel - .send(Ok(request_v2::CollationFetchingResponse::Collation( - candidate.clone(), - pov.clone(), - ) - .encode())) + .send(Ok(( + request_v2::CollationFetchingResponse::Collation(candidate.clone(), pov.clone()) + .encode(), + ProtocolName::from(""), + ))) .expect("Sending response should succeed"); // PVD request. diff --git a/polkadot/node/network/dispute-distribution/src/tests/mod.rs b/polkadot/node/network/dispute-distribution/src/tests/mod.rs index a3520bf35f80..880d1b18032c 100644 --- a/polkadot/node/network/dispute-distribution/src/tests/mod.rs +++ b/polkadot/node/network/dispute-distribution/src/tests/mod.rs @@ -32,7 +32,7 @@ use futures::{ use futures_timer::Delay; use parity_scale_codec::{Decode, Encode}; -use sc_network::config::RequestResponseConfig; +use sc_network::{config::RequestResponseConfig, ProtocolName}; use polkadot_node_network_protocol::{ request_response::{v1::DisputeRequest, IncomingRequest, ReqProtocolNames}, @@ -832,7 +832,7 @@ async fn check_sent_requests( if confirm_receive { for req in reqs { req.pending_response.send( - Ok(DisputeResponse::Confirmed.encode()) + Ok((DisputeResponse::Confirmed.encode(), ProtocolName::from(""))) ) .expect("Subsystem should be listening for a response."); } diff --git a/polkadot/node/network/protocol/src/request_response/mod.rs b/polkadot/node/network/protocol/src/request_response/mod.rs index 2df3021343df..a67d83aff0c9 100644 --- a/polkadot/node/network/protocol/src/request_response/mod.rs +++ b/polkadot/node/network/protocol/src/request_response/mod.rs @@ -30,7 +30,24 @@ //! `trait IsRequest` .... A trait describing a particular request. It is used for gathering meta //! data, like what is the corresponding response type. //! -//! Versioned (v1 module): The actual requests and responses as sent over the network. +//! ## Versioning +//! +//! Versioning for request-response protocols can be done in multiple ways. +//! +//! If you're just changing the protocol name but the binary payloads are the same, just add a new +//! `fallback_name` to the protocol config. +//! +//! One way in which versioning has historically been achieved for req-response protocols is to +//! bundle the new req-resp version with an upgrade of a notifications protocol. The subsystem would +//! then know which request version to use based on stored data about the peer's notifications +//! protocol version. +//! +//! When bumping a notifications protocol version is not needed/desirable, you may add a new +//! req-resp protocol and set the old request as a fallback (see +//! `OutgoingRequest::new_with_fallback`). A request with the new version will be attempted and if +//! the protocol is refused by the peer, the fallback protocol request will be used. +//! Information about the actually used protocol will be returned alongside the raw response, so +//! that you know how to decode it. use std::{collections::HashMap, time::Duration, u64}; @@ -188,11 +205,11 @@ impl Protocol { tx: Option>, ) -> RequestResponseConfig { let name = req_protocol_names.get_name(self); - let fallback_names = self.get_fallback_names(); + let legacy_names = self.get_legacy_name().into_iter().map(Into::into).collect(); match self { Protocol::ChunkFetchingV1 => RequestResponseConfig { name, - fallback_names, + fallback_names: legacy_names, max_request_size: 1_000, max_response_size: POV_RESPONSE_SIZE as u64 * 3, // We are connected to all validators: @@ -202,7 +219,7 @@ impl Protocol { Protocol::CollationFetchingV1 | Protocol::CollationFetchingV2 => RequestResponseConfig { name, - fallback_names, + fallback_names: legacy_names, max_request_size: 1_000, max_response_size: POV_RESPONSE_SIZE, // Taken from initial implementation in collator protocol: @@ -211,7 +228,7 @@ impl Protocol { }, Protocol::PoVFetchingV1 => RequestResponseConfig { name, - fallback_names, + fallback_names: legacy_names, max_request_size: 1_000, max_response_size: POV_RESPONSE_SIZE, request_timeout: POV_REQUEST_TIMEOUT_CONNECTED, @@ -219,7 +236,7 @@ impl Protocol { }, Protocol::AvailableDataFetchingV1 => RequestResponseConfig { name, - fallback_names, + fallback_names: legacy_names, max_request_size: 1_000, // Available data size is dominated by the PoV size. max_response_size: POV_RESPONSE_SIZE, @@ -228,7 +245,7 @@ impl Protocol { }, Protocol::StatementFetchingV1 => RequestResponseConfig { name, - fallback_names, + fallback_names: legacy_names, max_request_size: 1_000, // Available data size is dominated code size. max_response_size: STATEMENT_RESPONSE_SIZE, @@ -246,7 +263,7 @@ impl Protocol { }, Protocol::DisputeSendingV1 => RequestResponseConfig { name, - fallback_names, + fallback_names: legacy_names, max_request_size: 1_000, // Responses are just confirmation, in essence not even a bit. So 100 seems // plenty. @@ -256,7 +273,7 @@ impl Protocol { }, Protocol::AttestedCandidateV2 => RequestResponseConfig { name, - fallback_names, + fallback_names: legacy_names, max_request_size: 1_000, max_response_size: ATTESTED_CANDIDATE_RESPONSE_SIZE, request_timeout: ATTESTED_CANDIDATE_TIMEOUT, @@ -328,12 +345,9 @@ impl Protocol { } } - /// Fallback protocol names of this protocol, as understood by substrate networking. - fn get_fallback_names(self) -> Vec { - self.get_legacy_name().into_iter().map(Into::into).collect() - } - /// Legacy protocol name associated with each peer set, if any. + /// The request will be tried on this legacy protocol name if the remote refuses to speak the + /// protocol. const fn get_legacy_name(self) -> Option<&'static str> { match self { Protocol::ChunkFetchingV1 => Some("/polkadot/req_chunk/1"), @@ -360,6 +374,7 @@ pub trait IsRequest { } /// Type for getting on the wire [`Protocol`] names using genesis hash & fork id. +#[derive(Clone)] pub struct ReqProtocolNames { names: HashMap, } diff --git a/polkadot/node/network/protocol/src/request_response/outgoing.rs b/polkadot/node/network/protocol/src/request_response/outgoing.rs index c613d5778f5e..88439ad40367 100644 --- a/polkadot/node/network/protocol/src/request_response/outgoing.rs +++ b/polkadot/node/network/protocol/src/request_response/outgoing.rs @@ -14,8 +14,9 @@ // You should have received a copy of the GNU General Public License // along with Polkadot. If not, see . -use futures::{channel::oneshot, prelude::Future}; +use futures::{channel::oneshot, prelude::Future, FutureExt}; +use network::ProtocolName; use parity_scale_codec::{Decode, Encode, Error as DecodingError}; use sc_network as network; @@ -49,20 +50,6 @@ pub enum Requests { } impl Requests { - /// Get the protocol this request conforms to. - pub fn get_protocol(&self) -> Protocol { - match self { - Self::ChunkFetchingV1(_) => Protocol::ChunkFetchingV1, - Self::CollationFetchingV1(_) => Protocol::CollationFetchingV1, - Self::CollationFetchingV2(_) => Protocol::CollationFetchingV2, - Self::PoVFetchingV1(_) => Protocol::PoVFetchingV1, - Self::AvailableDataFetchingV1(_) => Protocol::AvailableDataFetchingV1, - Self::StatementFetchingV1(_) => Protocol::StatementFetchingV1, - Self::DisputeSendingV1(_) => Protocol::DisputeSendingV1, - Self::AttestedCandidateV2(_) => Protocol::AttestedCandidateV2, - } - } - /// Encode the request. /// /// The corresponding protocol is returned as well, as we are now leaving typed territory. @@ -85,7 +72,7 @@ impl Requests { } /// Used by the network to send us a response to a request. -pub type ResponseSender = oneshot::Sender, network::RequestFailure>>; +pub type ResponseSender = oneshot::Sender, ProtocolName), network::RequestFailure>>; /// Any error that can occur when sending a request. #[derive(Debug, thiserror::Error)] @@ -128,11 +115,13 @@ impl RequestError { /// When using `Recipient::Authority`, the addresses can be found thanks to the authority /// discovery system. #[derive(Debug)] -pub struct OutgoingRequest { +pub struct OutgoingRequest { /// Intended recipient of this request. pub peer: Recipient, /// The actual request to send over the wire. pub payload: Req, + /// Optional fallback request and protocol. + pub fallback_request: Option<(FallbackReq, Protocol)>, /// Sender which is used by networking to get us back a response. pub pending_response: ResponseSender, } @@ -149,10 +138,12 @@ pub enum Recipient { /// Responses received for an `OutgoingRequest`. pub type OutgoingResult = Result; -impl OutgoingRequest +impl OutgoingRequest where Req: IsRequest + Encode, Req::Response: Decode, + FallbackReq: IsRequest + Encode, + FallbackReq::Response: Decode, { /// Create a new `OutgoingRequest`. /// @@ -163,24 +154,54 @@ where payload: Req, ) -> (Self, impl Future>) { let (tx, rx) = oneshot::channel(); - let r = Self { peer, payload, pending_response: tx }; - (r, receive_response::(rx)) + let r = Self { peer, payload, pending_response: tx, fallback_request: None }; + (r, receive_response::(rx.map(|r| r.map(|r| r.map(|(resp, _)| resp))))) } + /// Create a new `OutgoingRequest` with a fallback in case the remote does not support this + /// protocol. Useful when adding a new version of a req-response protocol, to achieve + /// compatibility with the older version. + /// + /// Returns a raw `Vec` response over the channel. Use the associated `ProtocolName` to know + /// which request was the successful one and appropriately decode the response. + // WARNING: This is commented for now because it's not used yet. + // If you need it, make sure to test it. You may need to enable the V1 substream upgrade + // protocol, unless libp2p was in the meantime updated to a version that fixes the problem + // described in https://github.com/libp2p/rust-libp2p/issues/5074 + // pub fn new_with_fallback( + // peer: Recipient, + // payload: Req, + // fallback_request: FallbackReq, + // ) -> (Self, impl Future, ProtocolName)>>) { + // let (tx, rx) = oneshot::channel(); + // let r = Self { + // peer, + // payload, + // pending_response: tx, + // fallback_request: Some((fallback_request, FallbackReq::PROTOCOL)), + // }; + // (r, async { Ok(rx.await??) }) + // } + /// Encode a request into a `Vec`. /// /// As this throws away type information, we also return the `Protocol` this encoded request /// adheres to. pub fn encode_request(self) -> (Protocol, OutgoingRequest>) { - let OutgoingRequest { peer, payload, pending_response } = self; - let encoded = OutgoingRequest { peer, payload: payload.encode(), pending_response }; + let OutgoingRequest { peer, payload, pending_response, fallback_request } = self; + let encoded = OutgoingRequest { + peer, + payload: payload.encode(), + fallback_request: fallback_request.map(|(r, p)| (r.encode(), p)), + pending_response, + }; (Req::PROTOCOL, encoded) } } /// Future for actually receiving a typed response for an `OutgoingRequest`. async fn receive_response( - rec: oneshot::Receiver, network::RequestFailure>>, + rec: impl Future, network::RequestFailure>, oneshot::Canceled>>, ) -> OutgoingResult where Req: IsRequest, diff --git a/polkadot/node/network/statement-distribution/src/legacy_v1/tests.rs b/polkadot/node/network/statement-distribution/src/legacy_v1/tests.rs index 8ac9895ec5ad..2766ec9815af 100644 --- a/polkadot/node/network/statement-distribution/src/legacy_v1/tests.rs +++ b/polkadot/node/network/statement-distribution/src/legacy_v1/tests.rs @@ -50,6 +50,7 @@ use polkadot_primitives_test_helpers::{ dummy_committed_candidate_receipt, dummy_hash, AlwaysZeroRng, }; use sc_keystore::LocalKeystore; +use sc_network::ProtocolName; use sp_application_crypto::{sr25519::Pair, AppCrypto, Pair as TraitPair}; use sp_authority_discovery::AuthorityPair; use sp_keyring::Sr25519Keyring; @@ -1330,7 +1331,7 @@ fn receiving_large_statement_from_one_sends_to_another_and_to_candidate_backing( bad }; let response = StatementFetchingResponse::Statement(bad_candidate); - outgoing.pending_response.send(Ok(response.encode())).unwrap(); + outgoing.pending_response.send(Ok((response.encode(), ProtocolName::from("")))).unwrap(); } ); @@ -1382,7 +1383,7 @@ fn receiving_large_statement_from_one_sends_to_another_and_to_candidate_backing( // On retry, we should have reverse order: assert_eq!(outgoing.peer, Recipient::Peer(peer_c)); let response = StatementFetchingResponse::Statement(candidate.clone()); - outgoing.pending_response.send(Ok(response.encode())).unwrap(); + outgoing.pending_response.send(Ok((response.encode(), ProtocolName::from("")))).unwrap(); } ); @@ -1869,7 +1870,7 @@ fn delay_reputation_changes() { bad }; let response = StatementFetchingResponse::Statement(bad_candidate); - outgoing.pending_response.send(Ok(response.encode())).unwrap(); + outgoing.pending_response.send(Ok((response.encode(), ProtocolName::from("")))).unwrap(); } ); @@ -1913,7 +1914,7 @@ fn delay_reputation_changes() { // On retry, we should have reverse order: assert_eq!(outgoing.peer, Recipient::Peer(peer_c)); let response = StatementFetchingResponse::Statement(candidate.clone()); - outgoing.pending_response.send(Ok(response.encode())).unwrap(); + outgoing.pending_response.send(Ok((response.encode(), ProtocolName::from("")))).unwrap(); } ); diff --git a/polkadot/node/network/statement-distribution/src/v2/tests/mod.rs b/polkadot/node/network/statement-distribution/src/v2/tests/mod.rs index 3ce43202b954..bb780584febf 100644 --- a/polkadot/node/network/statement-distribution/src/v2/tests/mod.rs +++ b/polkadot/node/network/statement-distribution/src/v2/tests/mod.rs @@ -38,6 +38,7 @@ use polkadot_primitives::{ SessionIndex, SessionInfo, ValidatorPair, }; use sc_keystore::LocalKeystore; +use sc_network::ProtocolName; use sp_application_crypto::Pair as PairT; use sp_authority_discovery::AuthorityPair as AuthorityDiscoveryPair; use sp_keyring::Sr25519Keyring; @@ -684,7 +685,7 @@ async fn handle_sent_request( persisted_validation_data, statements, }; - outgoing.pending_response.send(Ok(res.encode())).unwrap(); + outgoing.pending_response.send(Ok((res.encode(), ProtocolName::from("")))).unwrap(); } ); } diff --git a/polkadot/node/network/statement-distribution/src/v2/tests/requests.rs b/polkadot/node/network/statement-distribution/src/v2/tests/requests.rs index 04934b31482e..dc2c8f55290b 100644 --- a/polkadot/node/network/statement-distribution/src/v2/tests/requests.rs +++ b/polkadot/node/network/statement-distribution/src/v2/tests/requests.rs @@ -22,8 +22,9 @@ use polkadot_node_network_protocol::{ request_response::v2 as request_v2, v2::BackedCandidateManifest, }; use polkadot_primitives_test_helpers::make_candidate; -use sc_network::config::{ - IncomingRequest as RawIncomingRequest, OutgoingResponse as RawOutgoingResponse, +use sc_network::{ + config::{IncomingRequest as RawIncomingRequest, OutgoingResponse as RawOutgoingResponse}, + ProtocolName, }; #[test] @@ -1342,7 +1343,7 @@ fn when_validator_disabled_after_sending_the_request() { persisted_validation_data: pvd, statements, }; - outgoing.pending_response.send(Ok(res.encode())).unwrap(); + outgoing.pending_response.send(Ok((res.encode(), ProtocolName::from("")))).unwrap(); } ); } diff --git a/polkadot/node/subsystem-bench/src/availability/mod.rs b/polkadot/node/subsystem-bench/src/availability/mod.rs index 7c81b9313659..faedccdf3e42 100644 --- a/polkadot/node/subsystem-bench/src/availability/mod.rs +++ b/polkadot/node/subsystem-bench/src/availability/mod.rs @@ -109,7 +109,12 @@ fn prepare_test_inner( chunks: state.chunks.clone(), }; - let network = NetworkEmulator::new(&config, &dependencies, &test_authorities); + let req_protocol_names = ReqProtocolNames::new(GENESIS_HASH, None); + let (collation_req_receiver, req_cfg) = + IncomingRequest::get_config_receiver(&req_protocol_names); + + let network = + NetworkEmulator::new(&config, &dependencies, &test_authorities, req_protocol_names); let network_bridge_tx = network_bridge::MockNetworkBridgeTx::new( config.clone(), @@ -122,9 +127,6 @@ fn prepare_test_inner( _ => panic!("Unexpected objective"), }; - let (collation_req_receiver, req_cfg) = - IncomingRequest::get_config_receiver(&ReqProtocolNames::new(GENESIS_HASH, None)); - let subsystem = if use_fast_path { AvailabilityRecoverySubsystem::with_fast_path( collation_req_receiver, diff --git a/polkadot/node/subsystem-bench/src/core/mock/network_bridge.rs b/polkadot/node/subsystem-bench/src/core/mock/network_bridge.rs index b106b832011a..5d534e37c991 100644 --- a/polkadot/node/subsystem-bench/src/core/mock/network_bridge.rs +++ b/polkadot/node/subsystem-bench/src/core/mock/network_bridge.rs @@ -33,7 +33,9 @@ use polkadot_node_subsystem::{ }; use polkadot_node_network_protocol::request_response::{ - self as req_res, v1::ChunkResponse, Requests, + self as req_res, + v1::{AvailableDataFetchingRequest, ChunkFetchingRequest, ChunkResponse}, + IsRequest, Requests, }; use polkadot_primitives::AuthorityDiscoveryId; @@ -144,7 +146,10 @@ impl MockNetworkBridgeTx { size = 0; Err(RequestFailure::Network(OutboundFailure::ConnectionClosed)) } else { - Ok(req_res::v1::ChunkFetchingResponse::from(Some(chunk)).encode()) + Ok(( + req_res::v1::ChunkFetchingResponse::from(Some(chunk)).encode(), + self.network.req_protocol_names().get_name(ChunkFetchingRequest::PROTOCOL), + )) }; let authority_discovery_id_clone = authority_discovery_id.clone(); @@ -212,8 +217,13 @@ impl MockNetworkBridgeTx { let response = if random_error(self.config.error) { Err(RequestFailure::Network(OutboundFailure::ConnectionClosed)) } else { - Ok(req_res::v1::AvailableDataFetchingResponse::from(Some(available_data)) - .encode()) + Ok(( + req_res::v1::AvailableDataFetchingResponse::from(Some(available_data)) + .encode(), + self.network + .req_protocol_names() + .get_name(AvailableDataFetchingRequest::PROTOCOL), + )) }; let future = async move { diff --git a/polkadot/node/subsystem-bench/src/core/network.rs b/polkadot/node/subsystem-bench/src/core/network.rs index c4e20b421d34..bbf61425f73d 100644 --- a/polkadot/node/subsystem-bench/src/core/network.rs +++ b/polkadot/node/subsystem-bench/src/core/network.rs @@ -19,6 +19,7 @@ use super::{ *, }; use colored::Colorize; +use polkadot_node_network_protocol::request_response::ReqProtocolNames; use polkadot_primitives::AuthorityDiscoveryId; use prometheus_endpoint::U64; use rand::{seq::SliceRandom, thread_rng}; @@ -311,6 +312,8 @@ pub struct NetworkEmulator { stats: Vec>, /// Each emulated peer is a validator. validator_authority_ids: HashMap, + /// Request protocol names + req_protocol_names: ReqProtocolNames, } impl NetworkEmulator { @@ -318,6 +321,7 @@ impl NetworkEmulator { config: &TestConfiguration, dependencies: &TestEnvironmentDependencies, authorities: &TestAuthorities, + req_protocol_names: ReqProtocolNames, ) -> Self { let n_peers = config.n_validators; gum::info!(target: LOG_TARGET, "{}",format!("Initializing emulation for a {} peer network.", n_peers).bright_blue()); @@ -355,7 +359,12 @@ impl NetworkEmulator { gum::info!(target: LOG_TARGET, "{}",format!("Network created, connected validator count {}", connected_count).bright_black()); - Self { peers, stats, validator_authority_ids: validator_authority_id_mapping } + Self { + peers, + stats, + validator_authority_ids: validator_authority_id_mapping, + req_protocol_names, + } } pub fn is_peer_connected(&self, peer: &AuthorityDiscoveryId) -> bool { @@ -428,6 +437,11 @@ impl NetworkEmulator { // Our node always is peer 0. self.peer_stats(0).inc_received(bytes); } + + // Get the request protocol names + pub fn req_protocol_names(&self) -> &ReqProtocolNames { + &self.req_protocol_names + } } use polkadot_node_subsystem_util::metrics::prometheus::{ diff --git a/prdoc/pr_2771.prdoc b/prdoc/pr_2771.prdoc new file mode 100644 index 000000000000..1b49162e4392 --- /dev/null +++ b/prdoc/pr_2771.prdoc @@ -0,0 +1,9 @@ +title: Add fallback request for req-response protocols + +doc: + - audience: Node Dev + description: | + Enable better req-response protocol versioning, by allowing for fallback requests on different protocols. + +crates: + - name: sc_network diff --git a/substrate/client/consensus/beefy/src/communication/request_response/outgoing_requests_engine.rs b/substrate/client/consensus/beefy/src/communication/request_response/outgoing_requests_engine.rs index ef462a54fca5..7121410ea109 100644 --- a/substrate/client/consensus/beefy/src/communication/request_response/outgoing_requests_engine.rs +++ b/substrate/client/consensus/beefy/src/communication/request_response/outgoing_requests_engine.rs @@ -43,7 +43,7 @@ use crate::{ }; /// Response type received from network. -type Response = Result, RequestFailure>; +type Response = Result<(Vec, ProtocolName), RequestFailure>; /// Used to receive a response from the network. type ResponseReceiver = oneshot::Receiver; @@ -125,6 +125,7 @@ impl OnDemandJustificationsEngine { peer, self.protocol_name.clone(), payload, + None, tx, IfDisconnected::ImmediateError, ); @@ -204,7 +205,7 @@ impl OnDemandJustificationsEngine { }, } }) - .and_then(|encoded| { + .and_then(|(encoded, _)| { decode_and_verify_finality_proof::( &encoded[..], req_info.block, diff --git a/substrate/client/network/src/behaviour.rs b/substrate/client/network/src/behaviour.rs index 745550412fc2..1f234683392f 100644 --- a/substrate/client/network/src/behaviour.rs +++ b/substrate/client/network/src/behaviour.rs @@ -231,13 +231,20 @@ impl Behaviour { pub fn send_request( &mut self, target: &PeerId, - protocol: &str, + protocol: ProtocolName, request: Vec, - pending_response: oneshot::Sender, RequestFailure>>, + fallback_request: Option<(Vec, ProtocolName)>, + pending_response: oneshot::Sender, ProtocolName), RequestFailure>>, connect: IfDisconnected, ) { - self.request_responses - .send_request(target, protocol, request, pending_response, connect) + self.request_responses.send_request( + target, + protocol, + request, + fallback_request, + pending_response, + connect, + ) } /// Returns a shared reference to the user protocol. diff --git a/substrate/client/network/src/request_responses.rs b/substrate/client/network/src/request_responses.rs index 5af072aaddc6..0cd1cf06bb33 100644 --- a/substrate/client/network/src/request_responses.rs +++ b/substrate/client/network/src/request_responses.rs @@ -56,6 +56,7 @@ use libp2p::{ use std::{ collections::{hash_map::Entry, HashMap}, io, iter, + ops::Deref, pin::Pin, task::{Context, Poll}, time::{Duration, Instant}, @@ -172,6 +173,13 @@ pub struct OutgoingResponse { pub sent_feedback: Option>, } +/// Information stored about a pending request. +struct PendingRequest { + started_at: Instant, + response_tx: oneshot::Sender, ProtocolName), RequestFailure>>, + fallback_request: Option<(Vec, ProtocolName)>, +} + /// When sending a request, what to do on a disconnected recipient. #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub enum IfDisconnected { @@ -264,8 +272,7 @@ pub struct RequestResponsesBehaviour { >, /// Pending requests, passed down to a request-response [`Behaviour`], awaiting a reply. - pending_requests: - HashMap, RequestFailure>>)>, + pending_requests: HashMap, /// Whenever an incoming request arrives, a `Future` is added to this list and will yield the /// start time and the response to send back to the remote. @@ -348,29 +355,25 @@ impl RequestResponsesBehaviour { pub fn send_request( &mut self, target: &PeerId, - protocol_name: &str, + protocol_name: ProtocolName, request: Vec, - pending_response: oneshot::Sender, RequestFailure>>, + fallback_request: Option<(Vec, ProtocolName)>, + pending_response: oneshot::Sender, ProtocolName), RequestFailure>>, connect: IfDisconnected, ) { log::trace!(target: "sub-libp2p", "send request to {target} ({protocol_name:?}), {} bytes", request.len()); - if let Some((protocol, _)) = self.protocols.get_mut(protocol_name) { - if protocol.is_connected(target) || connect.should_connect() { - let request_id = protocol.send_request(target, request); - let prev_req_id = self.pending_requests.insert( - (protocol_name.to_string().into(), request_id).into(), - (Instant::now(), pending_response), - ); - debug_assert!(prev_req_id.is_none(), "Expect request id to be unique."); - } else if pending_response.send(Err(RequestFailure::NotConnected)).is_err() { - log::debug!( - target: "sub-libp2p", - "Not connected to peer {:?}. At the same time local \ - node is no longer interested in the result.", - target, - ); - } + if let Some((protocol, _)) = self.protocols.get_mut(protocol_name.deref()) { + Self::send_request_inner( + protocol, + &mut self.pending_requests, + target, + protocol_name, + request, + fallback_request, + pending_response, + connect, + ) } else if pending_response.send(Err(RequestFailure::UnknownProtocol)).is_err() { log::debug!( target: "sub-libp2p", @@ -380,6 +383,37 @@ impl RequestResponsesBehaviour { ); } } + + fn send_request_inner( + behaviour: &mut Behaviour, + pending_requests: &mut HashMap, + target: &PeerId, + protocol_name: ProtocolName, + request: Vec, + fallback_request: Option<(Vec, ProtocolName)>, + pending_response: oneshot::Sender, ProtocolName), RequestFailure>>, + connect: IfDisconnected, + ) { + if behaviour.is_connected(target) || connect.should_connect() { + let request_id = behaviour.send_request(target, request); + let prev_req_id = pending_requests.insert( + (protocol_name.to_string().into(), request_id).into(), + PendingRequest { + started_at: Instant::now(), + response_tx: pending_response, + fallback_request, + }, + ); + debug_assert!(prev_req_id.is_none(), "Expect request id to be unique."); + } else if pending_response.send(Err(RequestFailure::NotConnected)).is_err() { + log::debug!( + target: "sub-libp2p", + "Not connected to peer {:?}. At the same time local \ + node is no longer interested in the result.", + target, + ); + } + } } impl NetworkBehaviour for RequestResponsesBehaviour { @@ -596,8 +630,10 @@ impl NetworkBehaviour for RequestResponsesBehaviour { } } + let mut fallback_requests = vec![]; + // Poll request-responses protocols. - for (protocol, (behaviour, resp_builder)) in &mut self.protocols { + for (protocol, (ref mut behaviour, ref mut resp_builder)) in &mut self.protocols { 'poll_protocol: while let Poll::Ready(ev) = behaviour.poll(cx, params) { let ev = match ev { // Main events we are interested in. @@ -698,17 +734,21 @@ impl NetworkBehaviour for RequestResponsesBehaviour { .pending_requests .remove(&(protocol.clone(), request_id).into()) { - Some((started, pending_response)) => { + Some(PendingRequest { started_at, response_tx, .. }) => { log::trace!( target: "sub-libp2p", "received response from {peer} ({protocol:?}), {} bytes", response.as_ref().map_or(0usize, |response| response.len()), ); - let delivered = pending_response - .send(response.map_err(|()| RequestFailure::Refused)) + let delivered = response_tx + .send( + response + .map_err(|()| RequestFailure::Refused) + .map(|resp| (resp, protocol.clone())), + ) .map_err(|_| RequestFailure::Obsolete); - (started, delivered) + (started_at, delivered) }, None => { log::warn!( @@ -742,8 +782,34 @@ impl NetworkBehaviour for RequestResponsesBehaviour { .pending_requests .remove(&(protocol.clone(), request_id).into()) { - Some((started, pending_response)) => { - if pending_response + Some(PendingRequest { + started_at, + response_tx, + fallback_request, + }) => { + // Try using the fallback request if the protocol was not + // supported. + if let OutboundFailure::UnsupportedProtocols = error { + if let Some((fallback_request, fallback_protocol)) = + fallback_request + { + log::trace!( + target: "sub-libp2p", + "Request with id {:?} failed. Trying the fallback protocol. {}", + request_id, + fallback_protocol.deref() + ); + fallback_requests.push(( + peer, + fallback_protocol, + fallback_request, + response_tx, + )); + continue + } + } + + if response_tx .send(Err(RequestFailure::Network(error.clone()))) .is_err() { @@ -754,7 +820,7 @@ impl NetworkBehaviour for RequestResponsesBehaviour { request_id, ); } - started + started_at }, None => { log::warn!( @@ -825,6 +891,25 @@ impl NetworkBehaviour for RequestResponsesBehaviour { } } + // Send out fallback requests. + for (peer, protocol, request, pending_response) in fallback_requests.drain(..) { + if let Some((behaviour, _)) = self.protocols.get_mut(&protocol) { + Self::send_request_inner( + behaviour, + &mut self.pending_requests, + &peer, + protocol, + request, + None, + pending_response, + // We can error if not connected because the + // previous attempt would have tried to establish a + // connection already or errored and we wouldn't have gotten here. + IfDisconnected::ImmediateError, + ); + } + } + break Poll::Pending } } @@ -976,6 +1061,7 @@ mod tests { use super::*; use crate::mock::MockPeerStore; + use assert_matches::assert_matches; use futures::{channel::oneshot, executor::LocalPool, task::Spawn}; use libp2p::{ core::{ @@ -1025,7 +1111,7 @@ mod tests { #[test] fn basic_request_response_works() { - let protocol_name = "/test/req-resp/1"; + let protocol_name = ProtocolName::from("/test/req-resp/1"); let mut pool = LocalPool::new(); // Build swarms whose behaviour is [`RequestResponsesBehaviour`]. @@ -1053,7 +1139,7 @@ mod tests { .unwrap(); let protocol_config = ProtocolConfig { - name: From::from(protocol_name), + name: protocol_name.clone(), fallback_names: Vec::new(), max_request_size: 1024, max_response_size: 1024 * 1024, @@ -1102,8 +1188,9 @@ mod tests { let (sender, receiver) = oneshot::channel(); swarm.behaviour_mut().send_request( &peer_id, - protocol_name, + protocol_name.clone(), b"this is a request".to_vec(), + None, sender, IfDisconnected::ImmediateError, ); @@ -1118,13 +1205,16 @@ mod tests { } } - assert_eq!(response_receiver.unwrap().await.unwrap().unwrap(), b"this is a response"); + assert_eq!( + response_receiver.unwrap().await.unwrap().unwrap(), + (b"this is a response".to_vec(), protocol_name) + ); }); } #[test] fn max_response_size_exceeded() { - let protocol_name = "/test/req-resp/1"; + let protocol_name = ProtocolName::from("/test/req-resp/1"); let mut pool = LocalPool::new(); // Build swarms whose behaviour is [`RequestResponsesBehaviour`]. @@ -1150,7 +1240,7 @@ mod tests { .unwrap(); let protocol_config = ProtocolConfig { - name: From::from(protocol_name), + name: protocol_name.clone(), fallback_names: Vec::new(), max_request_size: 1024, max_response_size: 8, // <-- important for the test @@ -1201,8 +1291,9 @@ mod tests { let (sender, receiver) = oneshot::channel(); swarm.behaviour_mut().send_request( &peer_id, - protocol_name, + protocol_name.clone(), b"this is a request".to_vec(), + None, sender, IfDisconnected::ImmediateError, ); @@ -1236,14 +1327,14 @@ mod tests { /// See [`ProtocolRequestId`] for additional information. #[test] fn request_id_collision() { - let protocol_name_1 = "/test/req-resp-1/1"; - let protocol_name_2 = "/test/req-resp-2/1"; + let protocol_name_1 = ProtocolName::from("/test/req-resp-1/1"); + let protocol_name_2 = ProtocolName::from("/test/req-resp-2/1"); let mut pool = LocalPool::new(); let mut swarm_1 = { let protocol_configs = vec![ ProtocolConfig { - name: From::from(protocol_name_1), + name: protocol_name_1.clone(), fallback_names: Vec::new(), max_request_size: 1024, max_response_size: 1024 * 1024, @@ -1251,7 +1342,7 @@ mod tests { inbound_queue: None, }, ProtocolConfig { - name: From::from(protocol_name_2), + name: protocol_name_2.clone(), fallback_names: Vec::new(), max_request_size: 1024, max_response_size: 1024 * 1024, @@ -1269,7 +1360,7 @@ mod tests { let protocol_configs = vec![ ProtocolConfig { - name: From::from(protocol_name_1), + name: protocol_name_1.clone(), fallback_names: Vec::new(), max_request_size: 1024, max_response_size: 1024 * 1024, @@ -1277,7 +1368,7 @@ mod tests { inbound_queue: Some(tx_1), }, ProtocolConfig { - name: From::from(protocol_name_2), + name: protocol_name_2.clone(), fallback_names: Vec::new(), max_request_size: 1024, max_response_size: 1024 * 1024, @@ -1359,15 +1450,17 @@ mod tests { let (sender_2, receiver_2) = oneshot::channel(); swarm_1.behaviour_mut().send_request( &peer_id, - protocol_name_1, + protocol_name_1.clone(), b"this is a request".to_vec(), + None, sender_1, IfDisconnected::ImmediateError, ); swarm_1.behaviour_mut().send_request( &peer_id, - protocol_name_2, + protocol_name_2.clone(), b"this is a request".to_vec(), + None, sender_2, IfDisconnected::ImmediateError, ); @@ -1385,8 +1478,239 @@ mod tests { } } let (response_receiver_1, response_receiver_2) = response_receivers.unwrap(); - assert_eq!(response_receiver_1.await.unwrap().unwrap(), b"this is a response"); - assert_eq!(response_receiver_2.await.unwrap().unwrap(), b"this is a response"); + assert_eq!( + response_receiver_1.await.unwrap().unwrap(), + (b"this is a response".to_vec(), protocol_name_1) + ); + assert_eq!( + response_receiver_2.await.unwrap().unwrap(), + (b"this is a response".to_vec(), protocol_name_2) + ); + }); + } + + #[test] + fn request_fallback() { + let protocol_name_1 = ProtocolName::from("/test/req-resp/2"); + let protocol_name_1_fallback = ProtocolName::from("/test/req-resp/1"); + let protocol_name_2 = ProtocolName::from("/test/another"); + let mut pool = LocalPool::new(); + + let protocol_config_1 = ProtocolConfig { + name: protocol_name_1.clone(), + fallback_names: Vec::new(), + max_request_size: 1024, + max_response_size: 1024 * 1024, + request_timeout: Duration::from_secs(30), + inbound_queue: None, + }; + let protocol_config_1_fallback = ProtocolConfig { + name: protocol_name_1_fallback.clone(), + fallback_names: Vec::new(), + max_request_size: 1024, + max_response_size: 1024 * 1024, + request_timeout: Duration::from_secs(30), + inbound_queue: None, + }; + let protocol_config_2 = ProtocolConfig { + name: protocol_name_2.clone(), + fallback_names: Vec::new(), + max_request_size: 1024, + max_response_size: 1024 * 1024, + request_timeout: Duration::from_secs(30), + inbound_queue: None, + }; + + // This swarm only speaks protocol_name_1_fallback and protocol_name_2. + // It only responds to requests. + let mut older_swarm = { + let (tx_1, mut rx_1) = async_channel::bounded::(64); + let (tx_2, mut rx_2) = async_channel::bounded::(64); + let mut protocol_config_1_fallback = protocol_config_1_fallback.clone(); + protocol_config_1_fallback.inbound_queue = Some(tx_1); + + let mut protocol_config_2 = protocol_config_2.clone(); + protocol_config_2.inbound_queue = Some(tx_2); + + pool.spawner() + .spawn_obj( + async move { + for _ in 0..2 { + if let Some(rq) = rx_1.next().await { + let (fb_tx, fb_rx) = oneshot::channel(); + assert_eq!(rq.payload, b"request on protocol /test/req-resp/1"); + let _ = rq.pending_response.send(super::OutgoingResponse { + result: Ok( + b"this is a response on protocol /test/req-resp/1".to_vec() + ), + reputation_changes: Vec::new(), + sent_feedback: Some(fb_tx), + }); + fb_rx.await.unwrap(); + } + } + + if let Some(rq) = rx_2.next().await { + let (fb_tx, fb_rx) = oneshot::channel(); + assert_eq!(rq.payload, b"request on protocol /test/other"); + let _ = rq.pending_response.send(super::OutgoingResponse { + result: Ok(b"this is a response on protocol /test/other".to_vec()), + reputation_changes: Vec::new(), + sent_feedback: Some(fb_tx), + }); + fb_rx.await.unwrap(); + } + } + .boxed() + .into(), + ) + .unwrap(); + + build_swarm(vec![protocol_config_1_fallback, protocol_config_2].into_iter()) + }; + + // This swarm speaks all protocols. + let mut new_swarm = build_swarm( + vec![ + protocol_config_1.clone(), + protocol_config_1_fallback.clone(), + protocol_config_2.clone(), + ] + .into_iter(), + ); + + { + let dial_addr = older_swarm.1.clone(); + Swarm::dial(&mut new_swarm.0, dial_addr).unwrap(); + } + + // Running `older_swarm`` in the background. + pool.spawner() + .spawn_obj({ + async move { + loop { + _ = older_swarm.0.select_next_some().await; + } + } + .boxed() + .into() + }) + .unwrap(); + + // Run the newer swarm. Attempt to make requests on all protocols. + let (mut swarm, _) = new_swarm; + let mut older_peer_id = None; + + pool.run_until(async move { + let mut response_receiver = None; + // Try the new protocol with a fallback. + loop { + match swarm.select_next_some().await { + SwarmEvent::ConnectionEstablished { peer_id, .. } => { + older_peer_id = Some(peer_id); + let (sender, receiver) = oneshot::channel(); + swarm.behaviour_mut().send_request( + &peer_id, + protocol_name_1.clone(), + b"request on protocol /test/req-resp/2".to_vec(), + Some(( + b"request on protocol /test/req-resp/1".to_vec(), + protocol_config_1_fallback.name.clone(), + )), + sender, + IfDisconnected::ImmediateError, + ); + response_receiver = Some(receiver); + }, + SwarmEvent::Behaviour(Event::RequestFinished { result, .. }) => { + result.unwrap(); + break + }, + _ => {}, + } + } + assert_eq!( + response_receiver.unwrap().await.unwrap().unwrap(), + ( + b"this is a response on protocol /test/req-resp/1".to_vec(), + protocol_name_1_fallback.clone() + ) + ); + // Try the old protocol with a useless fallback. + let (sender, response_receiver) = oneshot::channel(); + swarm.behaviour_mut().send_request( + older_peer_id.as_ref().unwrap(), + protocol_name_1_fallback.clone(), + b"request on protocol /test/req-resp/1".to_vec(), + Some(( + b"dummy request, will fail if processed".to_vec(), + protocol_config_1_fallback.name.clone(), + )), + sender, + IfDisconnected::ImmediateError, + ); + loop { + match swarm.select_next_some().await { + SwarmEvent::Behaviour(Event::RequestFinished { result, .. }) => { + result.unwrap(); + break + }, + _ => {}, + } + } + assert_eq!( + response_receiver.await.unwrap().unwrap(), + ( + b"this is a response on protocol /test/req-resp/1".to_vec(), + protocol_name_1_fallback.clone() + ) + ); + // Try the new protocol with no fallback. Should fail. + let (sender, response_receiver) = oneshot::channel(); + swarm.behaviour_mut().send_request( + older_peer_id.as_ref().unwrap(), + protocol_name_1.clone(), + b"request on protocol /test/req-resp-2".to_vec(), + None, + sender, + IfDisconnected::ImmediateError, + ); + loop { + match swarm.select_next_some().await { + SwarmEvent::Behaviour(Event::RequestFinished { result, .. }) => { + assert_matches!( + result.unwrap_err(), + RequestFailure::Network(OutboundFailure::UnsupportedProtocols) + ); + break + }, + _ => {}, + } + } + assert!(response_receiver.await.unwrap().is_err()); + // Try the other protocol with no fallback. + let (sender, response_receiver) = oneshot::channel(); + swarm.behaviour_mut().send_request( + older_peer_id.as_ref().unwrap(), + protocol_name_2.clone(), + b"request on protocol /test/other".to_vec(), + None, + sender, + IfDisconnected::ImmediateError, + ); + loop { + match swarm.select_next_some().await { + SwarmEvent::Behaviour(Event::RequestFinished { result, .. }) => { + result.unwrap(); + break + }, + _ => {}, + } + } + assert_eq!( + response_receiver.await.unwrap().unwrap(), + (b"this is a response on protocol /test/other".to_vec(), protocol_name_2.clone()) + ); }); } } diff --git a/substrate/client/network/src/service.rs b/substrate/client/network/src/service.rs index 06db23844d0d..47e23337633b 100644 --- a/substrate/client/network/src/service.rs +++ b/substrate/client/network/src/service.rs @@ -1048,11 +1048,12 @@ where target: PeerId, protocol: ProtocolName, request: Vec, + fallback_request: Option<(Vec, ProtocolName)>, connect: IfDisconnected, - ) -> Result, RequestFailure> { + ) -> Result<(Vec, ProtocolName), RequestFailure> { let (tx, rx) = oneshot::channel(); - self.start_request(target, protocol, request, tx, connect); + self.start_request(target, protocol, request, fallback_request, tx, connect); match rx.await { Ok(v) => v, @@ -1068,13 +1069,15 @@ where target: PeerId, protocol: ProtocolName, request: Vec, - tx: oneshot::Sender, RequestFailure>>, + fallback_request: Option<(Vec, ProtocolName)>, + tx: oneshot::Sender, ProtocolName), RequestFailure>>, connect: IfDisconnected, ) { let _ = self.to_worker.unbounded_send(ServiceToWorkerMsg::Request { target, protocol: protocol.into(), request, + fallback_request, pending_response: tx, connect, }); @@ -1160,7 +1163,8 @@ enum ServiceToWorkerMsg { target: PeerId, protocol: ProtocolName, request: Vec, - pending_response: oneshot::Sender, RequestFailure>>, + fallback_request: Option<(Vec, ProtocolName)>, + pending_response: oneshot::Sender, ProtocolName), RequestFailure>>, connect: IfDisconnected, }, NetworkStatus { @@ -1287,13 +1291,15 @@ where target, protocol, request, + fallback_request, pending_response, connect, } => { self.network_service.behaviour_mut().send_request( &target, - &protocol, + protocol, request, + fallback_request, pending_response, connect, ); diff --git a/substrate/client/network/src/service/traits.rs b/substrate/client/network/src/service/traits.rs index d4d4a05a86f1..74ddb986c247 100644 --- a/substrate/client/network/src/service/traits.rs +++ b/substrate/client/network/src/service/traits.rs @@ -551,8 +551,9 @@ pub trait NetworkRequest { target: PeerId, protocol: ProtocolName, request: Vec, + fallback_request: Option<(Vec, ProtocolName)>, connect: IfDisconnected, - ) -> Result, RequestFailure>; + ) -> Result<(Vec, ProtocolName), RequestFailure>; /// Variation of `request` which starts a request whose response is delivered on a provided /// channel. @@ -569,7 +570,8 @@ pub trait NetworkRequest { target: PeerId, protocol: ProtocolName, request: Vec, - tx: oneshot::Sender, RequestFailure>>, + fallback_request: Option<(Vec, ProtocolName)>, + tx: oneshot::Sender, ProtocolName), RequestFailure>>, connect: IfDisconnected, ); } @@ -585,13 +587,20 @@ where target: PeerId, protocol: ProtocolName, request: Vec, + fallback_request: Option<(Vec, ProtocolName)>, connect: IfDisconnected, - ) -> Pin, RequestFailure>> + Send + 'async_trait>> + ) -> Pin< + Box< + dyn Future, ProtocolName), RequestFailure>> + + Send + + 'async_trait, + >, + > where 'life0: 'async_trait, Self: 'async_trait, { - T::request(self, target, protocol, request, connect) + T::request(self, target, protocol, request, fallback_request, connect) } fn start_request( @@ -599,10 +608,11 @@ where target: PeerId, protocol: ProtocolName, request: Vec, - tx: oneshot::Sender, RequestFailure>>, + fallback_request: Option<(Vec, ProtocolName)>, + tx: oneshot::Sender, ProtocolName), RequestFailure>>, connect: IfDisconnected, ) { - T::start_request(self, target, protocol, request, tx, connect) + T::start_request(self, target, protocol, request, fallback_request, tx, connect) } } diff --git a/substrate/client/network/sync/src/block_relay_protocol.rs b/substrate/client/network/sync/src/block_relay_protocol.rs index 7a313458bf03..b4ef72a10c6b 100644 --- a/substrate/client/network/sync/src/block_relay_protocol.rs +++ b/substrate/client/network/sync/src/block_relay_protocol.rs @@ -18,7 +18,10 @@ use futures::channel::oneshot; use libp2p::PeerId; -use sc_network::request_responses::{ProtocolConfig, RequestFailure}; +use sc_network::{ + request_responses::{ProtocolConfig, RequestFailure}, + ProtocolName, +}; use sc_network_common::sync::message::{BlockData, BlockRequest}; use sp_runtime::traits::Block as BlockT; use std::sync::Arc; @@ -43,7 +46,7 @@ pub trait BlockDownloader: Send + Sync { &self, who: PeerId, request: BlockRequest, - ) -> Result, RequestFailure>, oneshot::Canceled>; + ) -> Result, ProtocolName), RequestFailure>, oneshot::Canceled>; /// Parses the protocol specific response to retrieve the block data. fn block_response_into_blocks( diff --git a/substrate/client/network/sync/src/block_request_handler.rs b/substrate/client/network/sync/src/block_request_handler.rs index f363dda3a2d1..f669a22cd2e9 100644 --- a/substrate/client/network/sync/src/block_request_handler.rs +++ b/substrate/client/network/sync/src/block_request_handler.rs @@ -570,7 +570,7 @@ impl BlockDownloader for FullBlockDownloader { &self, who: PeerId, request: BlockRequest, - ) -> Result, RequestFailure>, oneshot::Canceled> { + ) -> Result, ProtocolName), RequestFailure>, oneshot::Canceled> { // Build the request protobuf. let bytes = BlockRequestSchema { fields: request.fields.to_be_u32(), diff --git a/substrate/client/network/sync/src/engine.rs b/substrate/client/network/sync/src/engine.rs index d7b024cd801c..952300a14d89 100644 --- a/substrate/client/network/sync/src/engine.rs +++ b/substrate/client/network/sync/src/engine.rs @@ -1263,7 +1263,7 @@ where let ResponseEvent { peer_id, request, response } = response_event; match response { - Ok(Ok(resp)) => match request { + Ok(Ok((resp, _))) => match request { PeerRequest::Block(req) => { match self.block_downloader.block_response_into_blocks(&req, resp) { Ok(blocks) => { diff --git a/substrate/client/network/sync/src/mock.rs b/substrate/client/network/sync/src/mock.rs index 42220096e069..a4f5eb564c2c 100644 --- a/substrate/client/network/sync/src/mock.rs +++ b/substrate/client/network/sync/src/mock.rs @@ -22,7 +22,7 @@ use crate::block_relay_protocol::{BlockDownloader as BlockDownloaderT, BlockResp use futures::channel::oneshot; use libp2p::PeerId; -use sc_network::RequestFailure; +use sc_network::{ProtocolName, RequestFailure}; use sc_network_common::sync::message::{BlockData, BlockRequest}; use sp_runtime::traits::Block as BlockT; @@ -35,7 +35,7 @@ mockall::mock! { &self, who: PeerId, request: BlockRequest, - ) -> Result, RequestFailure>, oneshot::Canceled>; + ) -> Result, ProtocolName), RequestFailure>, oneshot::Canceled>; fn block_response_into_blocks( &self, request: &BlockRequest, diff --git a/substrate/client/network/sync/src/pending_responses.rs b/substrate/client/network/sync/src/pending_responses.rs index 55308dfc1ea9..e21a57632250 100644 --- a/substrate/client/network/sync/src/pending_responses.rs +++ b/substrate/client/network/sync/src/pending_responses.rs @@ -28,7 +28,7 @@ use futures::{ }; use libp2p::PeerId; use log::error; -use sc_network::request_responses::RequestFailure; +use sc_network::{request_responses::RequestFailure, types::ProtocolName}; use sp_runtime::traits::Block as BlockT; use std::task::{Context, Poll, Waker}; use tokio_stream::StreamMap; @@ -37,7 +37,7 @@ use tokio_stream::StreamMap; const LOG_TARGET: &'static str = "sync"; /// Response result. -type ResponseResult = Result, RequestFailure>, oneshot::Canceled>; +type ResponseResult = Result, ProtocolName), RequestFailure>, oneshot::Canceled>; /// A future yielding [`ResponseResult`]. type ResponseFuture = BoxFuture<'static, ResponseResult>; diff --git a/substrate/client/network/sync/src/service/mock.rs b/substrate/client/network/sync/src/service/mock.rs index 6e307d869844..420de8cd5fdc 100644 --- a/substrate/client/network/sync/src/service/mock.rs +++ b/substrate/client/network/sync/src/service/mock.rs @@ -117,14 +117,16 @@ mockall::mock! { target: PeerId, protocol: ProtocolName, request: Vec, + fallback_request: Option<(Vec, ProtocolName)>, connect: IfDisconnected, - ) -> Result, RequestFailure>; + ) -> Result<(Vec, ProtocolName), RequestFailure>; fn start_request( &self, target: PeerId, protocol: ProtocolName, request: Vec, - tx: oneshot::Sender, RequestFailure>>, + fallback_request: Option<(Vec, ProtocolName)>, + tx: oneshot::Sender, ProtocolName), RequestFailure>>, connect: IfDisconnected, ); } diff --git a/substrate/client/network/sync/src/service/network.rs b/substrate/client/network/sync/src/service/network.rs index 12a47d6a9b54..07f28519afb2 100644 --- a/substrate/client/network/sync/src/service/network.rs +++ b/substrate/client/network/sync/src/service/network.rs @@ -54,7 +54,7 @@ pub enum ToServiceCommand { PeerId, ProtocolName, Vec, - oneshot::Sender, RequestFailure>>, + oneshot::Sender, ProtocolName), RequestFailure>>, IfDisconnected, ), @@ -94,7 +94,7 @@ impl NetworkServiceHandle { who: PeerId, protocol: ProtocolName, request: Vec, - tx: oneshot::Sender, RequestFailure>>, + tx: oneshot::Sender, ProtocolName), RequestFailure>>, connect: IfDisconnected, ) { let _ = self @@ -134,7 +134,7 @@ impl NetworkServiceProvider { ToServiceCommand::ReportPeer(peer, reputation_change) => service.report_peer(peer, reputation_change), ToServiceCommand::StartRequest(peer, protocol, request, tx, connect) => - service.start_request(peer, protocol, request, tx, connect), + service.start_request(peer, protocol, request, None, tx, connect), ToServiceCommand::WriteNotification(peer, protocol, message) => service.write_notification(peer, protocol, message), ToServiceCommand::SetNotificationHandshake(protocol, handshake) =>