From c151baa4e0550b38f01815d5b68500009d57e6e7 Mon Sep 17 00:00:00 2001
From: Waclaw Banasik
Date: Fri, 6 Dec 2024 12:33:37 +0000
Subject: [PATCH] fix(resharding) - fix the buffered receipts forwarding (#12561)

The new logic introduced recently in #12538 relies on some assumptions
that do not hold in ShardLayoutV1. It's important to keep this shard
layout working, as it is what we currently have in mainnet.

The issue was that in V1 there may be shard ids shared between the
parent shards that were split and the current shards. In this case the
same buffer would be processed twice and crash the node.

The fix is to make the forwarding depend on the protocol version. For
old protocol versions we skip forwarding receipts from parents - there
shouldn't be any before the new resharding. Starting from the new
resharding it is ok to assume unique shard ids, so the new logic works.

For a reason that I can't remember, the `test_resharding_v3_load_mem_trie`
test broke and I needed to make some adjustments to it. Maybe it's
because the new base shard layout is non-trivial now.
---
 chain/chain/src/resharding/event_type.rs   |  5 +
 core/primitives/src/shard_layout.rs        | 19 +++-
 .../test_loop/tests/fix_min_stake_ratio.rs | 24 +++--
 .../src/test_loop/tests/resharding_v3.rs   | 98 +++++++++++++++----
 runtime/runtime/src/congestion_control.rs  | 31 +++---
 5 files changed, 131 insertions(+), 46 deletions(-)

diff --git a/chain/chain/src/resharding/event_type.rs b/chain/chain/src/resharding/event_type.rs
index 188880978c4..b62f7970ee5 100644
--- a/chain/chain/src/resharding/event_type.rs
+++ b/chain/chain/src/resharding/event_type.rs
@@ -79,6 +79,11 @@ impl ReshardingEventType {
                     return log_and_error("can't perform two reshardings at the same time!");
                 }
                 // Parent shard is no longer part of this shard layout.
+                //
+                // Please note the use of the next shard layout version.
+                // Technically speaking the current shard layout version
+                // should be used for the parent. However since
+                // ShardLayoutV2 the version is frozen so it is ok.
                 let parent_shard = ShardUId::new(next_shard_layout.version(), *parent_id);
                 let left_child_shard =
                     ShardUId::from_shard_id_and_layout(children_ids[0], next_shard_layout);
diff --git a/core/primitives/src/shard_layout.rs b/core/primitives/src/shard_layout.rs
index 874c56f8525..64c97bbc63b 100644
--- a/core/primitives/src/shard_layout.rs
+++ b/core/primitives/src/shard_layout.rs
@@ -7,7 +7,7 @@ use near_schema_checker_lib::ProtocolSchema;
 use rand::rngs::StdRng;
 use rand::seq::SliceRandom;
 use rand::SeedableRng;
-use std::collections::BTreeMap;
+use std::collections::{BTreeMap, BTreeSet};
 use std::{fmt, str};
 
 /// This file implements two data structure `ShardLayout` and `ShardUId`
@@ -784,6 +784,23 @@ impl ShardLayout {
                 .ok_or(ShardLayoutError::InvalidShardIndexError { shard_index }),
         }
     }
+
+    /// Returns all of the shards from the previous shard layout that were
+    /// split into multiple shards in this shard layout.
+    pub fn get_parent_shard_ids(&self) -> Result<BTreeSet<ShardId>, ShardLayoutError> {
+        let mut parent_shard_ids = BTreeSet::new();
+        for shard_id in self.shard_ids() {
+            let parent_shard_id = self.try_get_parent_shard_id(shard_id)?;
+            let Some(parent_shard_id) = parent_shard_id else {
+                continue;
+            };
+            if parent_shard_id == shard_id {
+                continue;
+            }
+            parent_shard_ids.insert(parent_shard_id);
+        }
+        Ok(parent_shard_ids)
+    }
 }
 
 /// Maps an account to the shard that it belongs to given a shard_layout
diff --git a/integration-tests/src/test_loop/tests/fix_min_stake_ratio.rs b/integration-tests/src/test_loop/tests/fix_min_stake_ratio.rs
index 5a60cc21b5f..cf9b7921a3e 100644
--- a/integration-tests/src/test_loop/tests/fix_min_stake_ratio.rs
+++ b/integration-tests/src/test_loop/tests/fix_min_stake_ratio.rs
@@ -15,6 +15,7 @@ use near_primitives::num_rational::Rational32;
 use near_primitives::test_utils::create_user_test_signer;
 use near_primitives::transaction::SignedTransaction;
 use near_primitives::types::AccountId;
+use near_primitives::upgrade_schedule::ProtocolUpgradeVotingSchedule;
 use near_primitives_core::version::ProtocolFeature;
 use std::string::ToString;
 use std::sync::atomic::{AtomicU64, Ordering};
@@ -24,7 +25,18 @@ use std::sync::atomic::{AtomicU64, Ordering};
 #[test]
 fn slow_test_fix_min_stake_ratio() {
     init_test_logger();
-    let builder = TestLoopBuilder::new();
+
+    // Take epoch configuration before the protocol upgrade, where minimum
+    // stake ratio was 1/6250.
+    let epoch_config_store = EpochConfigStore::for_chain_id("mainnet", None).unwrap();
+    let target_protocol_version = ProtocolFeature::FixMinStakeRatio.protocol_version();
+    let genesis_protocol_version = target_protocol_version - 1;
+
+    // Immediately start voting for the new protocol version
+    let protocol_upgrade_schedule =
+        ProtocolUpgradeVotingSchedule::new_immediate(target_protocol_version);
+
+    let builder = TestLoopBuilder::new().protocol_upgrade_schedule(protocol_upgrade_schedule);
 
     let initial_balance = 1_000_000 * ONE_NEAR;
     let epoch_length = 10;
@@ -56,18 +68,16 @@ fn slow_test_fix_min_stake_ratio() {
         },
     ];
 
-    // Take epoch configuration before the protocol upgrade, where minimum
-    // stake ratio was 1/6250.
-    let epoch_config_store = EpochConfigStore::for_chain_id("mainnet", None).unwrap();
-    let protocol_version = ProtocolFeature::FixMinStakeRatio.protocol_version() - 1;
+    let shard_layout =
+        epoch_config_store.get_config(genesis_protocol_version).as_ref().shard_layout.clone();
 
     // Create chain with version before FixMinStakeRatio was enabled.
     // Check that the small validator is not included in the validator set.
     let mut genesis_builder = TestGenesisBuilder::new();
     genesis_builder
         .genesis_time_from_clock(&builder.clock())
-        .shard_layout(epoch_config_store.get_config(protocol_version).as_ref().shard_layout.clone())
-        .protocol_version(protocol_version)
+        .shard_layout(shard_layout)
+        .protocol_version(genesis_protocol_version)
         .epoch_length(epoch_length)
         .validators_raw(validators, 1, 1, 2)
         // Disable validator rewards.
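As a standalone illustration (not part of the patch), the sketch below shows the hazard the commit message describes. It uses plain integer shard ids and a toy forwarding order rather than the real `ShardLayout`/`ShardUId` APIs, and the concrete ids are made up for the example: in a V1-style layout the parent id of a split shard can coincide with a live shard id, so a loop that visits parent buffers and then current buffers touches the same buffer twice, while the unique ids assumed from the new resharding onward keep the two id sets disjoint.

```rust
// Standalone sketch with toy shard ids; the real nearcore types are not used here.
use std::collections::BTreeSet;

type ShardId = u64;

// Forward buffered receipts for the parent shards first, then for the current
// shards. If the two id sets overlap, the shared id is visited twice, which is
// the double processing the commit message says crashes the node.
fn forwarding_order(parents: &BTreeSet<ShardId>, current: &BTreeSet<ShardId>) -> Vec<ShardId> {
    parents.iter().chain(current.iter()).copied().collect()
}

fn main() {
    // ShardLayoutV1 style: ids are reused between layouts. Say the old layout
    // had shards 0..=3 and shard 2 was split; the new layout numbers its
    // shards 0..=4 again, so the parent id 2 collides with a live shard id.
    let parents_v1: BTreeSet<ShardId> = BTreeSet::from([2]);
    let current_v1: BTreeSet<ShardId> = (0..=4).collect();
    let order = forwarding_order(&parents_v1, &current_v1);
    assert_eq!(order.iter().filter(|id| **id == 2).count(), 2); // buffer 2 visited twice

    // ShardLayoutV2 style (from the new resharding onward): children get fresh
    // ids and the parent id never reappears, so the sets are disjoint and each
    // buffer is forwarded exactly once.
    let parents_v2: BTreeSet<ShardId> = BTreeSet::from([2]);
    let current_v2: BTreeSet<ShardId> = BTreeSet::from([0, 1, 3, 4, 5, 6]);
    assert_eq!(parents_v2.intersection(&current_v2).count(), 0);
}
```

This disjointness is the property the `debug_assert!` added in congestion_control.rs checks, and it is why the parent-buffer forwarding is only enabled once `ProtocolFeature::SimpleNightshadeV4` is active.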
diff --git a/integration-tests/src/test_loop/tests/resharding_v3.rs b/integration-tests/src/test_loop/tests/resharding_v3.rs
index d3f05284f4d..11fb11b0afb 100644
--- a/integration-tests/src/test_loop/tests/resharding_v3.rs
+++ b/integration-tests/src/test_loop/tests/resharding_v3.rs
@@ -12,7 +12,7 @@ use near_primitives::epoch_manager::EpochConfigStore;
 use near_primitives::hash::CryptoHash;
 use near_primitives::shard_layout::{account_id_to_shard_uid, ShardLayout};
 use near_primitives::state_record::StateRecord;
-use near_primitives::types::{AccountId, BlockHeightDelta, EpochId, Gas, ShardId};
+use near_primitives::types::{AccountId, BlockHeightDelta, EpochId, Gas, NumShards, ShardId};
 use near_primitives::version::{ProtocolFeature, PROTOCOL_VERSION};
 use near_store::adapter::StoreAdapter;
 use near_store::db::refcount::decode_value_with_rc;
@@ -590,24 +590,50 @@ fn shard_was_split(shard_layout: &ShardLayout, shard_id: ShardId) -> bool {
     parent != shard_id
 }
 
-/// Asserts that for each child shard:
-/// MemTrie, FlatState and DiskTrie all contain the same key-value pairs.
-/// If `load_mem_tries_for_tracked_shards` is false, we only enforce memtries for split shards
-/// Returns the ShardUIds that this client tracks and has sane memtries and flat storage for
+/// Asserts that for each child shard, MemTrie, FlatState and DiskTrie all
+/// contain the same key-value pairs. If `load_mem_tries_for_tracked_shards` is
+/// false, we only enforce memtries for shards pending resharding in the old
+/// layout and the shards that were split in the new shard layout.
+///
+/// Returns the ShardUIds that this client tracks and has sane memtries and flat
+/// storage for.
+///
+/// The new num shards argument is a clumsy way to check if the head is before
+/// or after resharding.
 fn assert_state_sanity(
     client: &Client,
     final_head: &Tip,
     load_mem_tries_for_tracked_shards: bool,
+    new_num_shards: NumShards,
 ) -> Vec<ShardUId> {
     let shard_layout = client.epoch_manager.get_shard_layout(&final_head.epoch_id).unwrap();
+    let is_resharded = shard_layout.num_shards() == new_num_shards;
     let mut checked_shards = Vec::new();
 
+    let protocol_version =
+        client.epoch_manager.get_epoch_protocol_version(&final_head.epoch_id).unwrap();
+    let shards_pending_resharding = client
+        .epoch_manager
+        .get_shard_uids_pending_resharding(protocol_version, PROTOCOL_VERSION)
+        .unwrap();
+
     for shard_uid in shard_layout.shard_uids() {
-        if !load_mem_tries_for_tracked_shards
-            && !shard_was_split(&shard_layout, shard_uid.shard_id())
-        {
-            continue;
+        // TODO - the condition for checks is duplicated in the
+        // `get_epoch_check` method, refactor this.
+        if !load_mem_tries_for_tracked_shards {
+            // In the old layout do not enforce except for shards pending resharding.
+            if !is_resharded && !shards_pending_resharding.contains(&shard_uid) {
+                tracing::debug!(target: "test", ?shard_uid, "skipping shard not pending resharding");
+                continue;
+            }
+
+            // In the new layout do not enforce for shards that were not split.
+            if is_resharded && !shard_was_split(&shard_layout, shard_uid.shard_id()) {
+                tracing::debug!(target: "test", ?shard_uid, "skipping shard not split");
+                continue;
+            }
         }
+
         if !client_tracking_shard(client, shard_uid.shard_id(), &final_head.prev_block_hash) {
             continue;
         }
@@ -696,22 +722,44 @@ impl TrieSanityCheck {
     }
 
     // If it's not already stored, initialize it with the expected ShardUIds for each account
-    fn get_epoch_check(&mut self, client: &Client, tip: &Tip) -> &mut EpochTrieCheck {
+    fn get_epoch_check(
+        &mut self,
+        client: &Client,
+        tip: &Tip,
+        new_num_shards: NumShards,
+    ) -> &mut EpochTrieCheck {
+        let protocol_version =
+            client.epoch_manager.get_epoch_protocol_version(&tip.epoch_id).unwrap();
+        let shards_pending_resharding = client
+            .epoch_manager
+            .get_shard_uids_pending_resharding(protocol_version, PROTOCOL_VERSION)
+            .unwrap();
+        let shard_layout = client.epoch_manager.get_shard_layout(&tip.epoch_id).unwrap();
+        let is_resharded = shard_layout.num_shards() == new_num_shards;
+
         match self.checks.entry(tip.epoch_id) {
             std::collections::hash_map::Entry::Occupied(e) => e.into_mut(),
             std::collections::hash_map::Entry::Vacant(e) => {
-                let shard_layout = client.epoch_manager.get_shard_layout(&tip.epoch_id).unwrap();
                 let shard_uids = shard_layout.shard_uids().collect_vec();
                 let mut check = HashMap::new();
                 for account_id in self.accounts.iter() {
                     let tracked = shard_uids
                         .iter()
                         .filter_map(|uid| {
-                            if !self.load_mem_tries_for_tracked_shards
+                            if !is_resharded
+                                && !self.load_mem_tries_for_tracked_shards
+                                && !shards_pending_resharding.contains(uid)
+                            {
+                                return None;
+                            }
+
+                            if is_resharded
+                                && !self.load_mem_tries_for_tracked_shards
                                 && !shard_was_split(&shard_layout, uid.shard_id())
                             {
                                 return None;
                             }
+
                             let cares = client.shard_tracker.care_about_shard(
                                 Some(account_id),
                                 &tip.prev_block_hash,
@@ -733,7 +781,7 @@ impl TrieSanityCheck {
     }
 
     // Check trie sanity and keep track of which shards were succesfully fully checked
-    fn assert_state_sanity(&mut self, clients: &[&Client]) {
+    fn assert_state_sanity(&mut self, clients: &[&Client], new_num_shards: NumShards) {
         for client in clients {
             let signer = client.validator_signer.get();
             let Some(account_id) = signer.as_ref().map(|s| s.validator_id()) else {
@@ -753,9 +801,13 @@ impl TrieSanityCheck {
             if head.epoch_id != final_head.epoch_id {
                 continue;
             }
-            let checked_shards =
-                assert_state_sanity(client, &final_head, self.load_mem_tries_for_tracked_shards);
-            let check = self.get_epoch_check(client, &head);
+            let checked_shards = assert_state_sanity(
+                client,
+                &final_head,
+                self.load_mem_tries_for_tracked_shards,
+                new_num_shards,
+            );
+            let check = self.get_epoch_check(client, &head, new_num_shards);
             let check = check.get_mut(account_id).unwrap();
             for shard_uid in checked_shards {
                 check.insert(shard_uid, true);
@@ -834,8 +886,13 @@ fn test_resharding_v3_base(params: TestReshardingParameters) {
         base_epoch_config.chunk_validator_only_kickout_threshold = 0;
     }
 
+    // Set the base shard layout to V1 with non-unique shard ids. This is what
+    // we have on mainnet before the ReshardingV3 release.
+    // TODO(resharding) test both V1->V2 and V2->V2
     let boundary_accounts = vec!["account1".parse().unwrap(), "account3".parse().unwrap()];
-    let base_shard_layout = ShardLayout::multi_shard_custom(boundary_accounts, 3);
+    let split_map = vec![vec![ShardId::new(0), ShardId::new(1), ShardId::new(2)]];
+    #[allow(deprecated)]
+    let base_shard_layout = ShardLayout::v1(boundary_accounts, Some(split_map), 3);
     base_epoch_config.shard_layout = base_shard_layout.clone();
 
     let new_boundary_account = "account6".parse().unwrap();
@@ -846,7 +903,7 @@
         ShardLayout::derive_shard_layout(&base_shard_layout, new_boundary_account);
     tracing::info!(target: "test", ?base_shard_layout, new_shard_layout=?epoch_config.shard_layout, "shard layout");
-    let expected_num_shards = epoch_config.shard_layout.shard_ids().count();
+    let expected_num_shards = epoch_config.shard_layout.num_shards();
     let epoch_config_store = EpochConfigStore::test(BTreeMap::from_iter(vec![
         (base_protocol_version, Arc::new(base_epoch_config)),
         (base_protocol_version + 1, Arc::new(epoch_config)),
     ]));
@@ -945,9 +1002,8 @@ fn test_resharding_v3_base(params: TestReshardingParameters) {
             println!("State before resharding:");
             print_and_assert_shard_accounts(&clients, &tip);
         }
-        trie_sanity_check.assert_state_sanity(&clients);
+        trie_sanity_check.assert_state_sanity(&clients, expected_num_shards);
         latest_block_height.set(tip.height);
-        println!("block: {} chunks: {:?}", tip.height, block_header.chunk_mask());
         if params.all_chunks_expected && params.chunk_ranges_to_drop.is_empty() {
             assert!(block_header.chunk_mask().iter().all(|chunk_bit| *chunk_bit));
         }
@@ -960,7 +1016,7 @@ fn test_resharding_v3_base(params: TestReshardingParameters) {
         let prev_epoch_id =
            client.epoch_manager.get_prev_epoch_id_from_prev_block(&tip.prev_block_hash).unwrap();
         let epoch_config = client.epoch_manager.get_epoch_config(&prev_epoch_id).unwrap();
-        if epoch_config.shard_layout.shard_ids().count() != expected_num_shards {
+        if epoch_config.shard_layout.num_shards() != expected_num_shards {
             return false;
         }
 
diff --git a/runtime/runtime/src/congestion_control.rs b/runtime/runtime/src/congestion_control.rs
index 9f096c1c040..dede9b12b9c 100644
--- a/runtime/runtime/src/congestion_control.rs
+++ b/runtime/runtime/src/congestion_control.rs
@@ -234,24 +234,21 @@ impl ReceiptSinkV2 {
         let protocol_version = apply_state.current_protocol_version;
         let shard_layout = epoch_info_provider.shard_layout(&apply_state.epoch_id)?;
 
-        let shard_ids = if ProtocolFeature::SimpleNightshadeV4.enabled(protocol_version) {
-            shard_layout.shard_ids().collect_vec()
-        } else {
-            self.outgoing_limit.keys().copied().collect_vec()
-        };
-
-        let mut parent_shard_ids = BTreeSet::new();
-        for &shard_id in &shard_ids {
-            let parent_shard_id =
-                shard_layout.try_get_parent_shard_id(shard_id).map_err(Into::::into)?;
-            let Some(parent_shard_id) = parent_shard_id else {
-                continue;
+        let (shard_ids, parent_shard_ids) =
+            if ProtocolFeature::SimpleNightshadeV4.enabled(protocol_version) {
+                (
+                    shard_layout.shard_ids().collect_vec(),
+                    shard_layout.get_parent_shard_ids().map_err(Into::::into)?,
+                )
+            } else {
+                (self.outgoing_limit.keys().copied().collect_vec(), BTreeSet::new())
             };
-            if parent_shard_id == shard_id {
-                continue;
-            }
-            parent_shard_ids.insert(parent_shard_id);
-        }
+
+        // There must not be any shard id that is both a parent shard id and a
+        // current shard id, otherwise the same buffer would be processed twice.
+        debug_assert!(
+            parent_shard_ids.intersection(&shard_ids.clone().into_iter().collect()).count() == 0
+        );
 
         // First forward any receipts that may still be in the outgoing buffers
         // of the parent shards.