From 537b8d97cb71ea5e9c4648b8cdf4520ec5e22363 Mon Sep 17 00:00:00 2001
From: Andrea
Date: Thu, 12 Dec 2024 12:09:38 +0100
Subject: [PATCH] refactor: Move out utilities from resharding_v3 test loop
 (#12601)

The resharding V3 test loop became a beast of a file rather quickly. This PR
simply moves some utility functions out of the test into three other files,
grouped by theme:
- `receipts.rs` for everything related to checking the existence of receipts
- `sharding.rs` for everything related to shards and shard layout
- `trie_sanity.rs` for the new `TrieSanityCheck` implemented by @marcelo-gonzalez

Hopefully this will make these utilities more reusable.

`TestReshardingParameters` is not touched. It still needs improvements, though.
---
 .../src/test_loop/tests/resharding_v3.rs      | 656 +-----------------
 integration-tests/src/test_loop/utils/mod.rs  |  21 +-
 .../src/test_loop/utils/receipts.rs           | 157 +++++
 .../src/test_loop/utils/sharding.rs           | 120 ++++
 .../src/test_loop/utils/transactions.rs       |  18 +
 .../src/test_loop/utils/trie_sanity.rs        | 369 ++++++++++
 6 files changed, 694 insertions(+), 647 deletions(-)
 create mode 100644 integration-tests/src/test_loop/utils/receipts.rs
 create mode 100644 integration-tests/src/test_loop/utils/sharding.rs
 create mode 100644 integration-tests/src/test_loop/utils/trie_sanity.rs

diff --git a/integration-tests/src/test_loop/tests/resharding_v3.rs b/integration-tests/src/test_loop/tests/resharding_v3.rs
index b787c501725..c6acf625705 100644
--- a/integration-tests/src/test_loop/tests/resharding_v3.rs
+++ b/integration-tests/src/test_loop/tests/resharding_v3.rs
@@ -1,132 +1,43 @@
-use borsh::BorshDeserialize;
 use itertools::Itertools;
 use near_async::test_loop::data::{TestLoopData, TestLoopDataHandle};
 use near_async::time::Duration;
-use near_chain::ChainStoreAccess;
 use near_chain_configs::test_genesis::{TestGenesisBuilder, ValidatorsSpec};
 use near_chain_configs::DEFAULT_GC_NUM_EPOCHS_TO_KEEP;
-use near_client::Client;
 use near_o11y::testonly::init_test_logger;
-use near_primitives::block::Tip;
 use near_primitives::epoch_manager::EpochConfigStore;
-use near_primitives::hash::CryptoHash;
 use near_primitives::shard_layout::{account_id_to_shard_uid, ShardLayout};
-use near_primitives::state_record::StateRecord;
-use near_primitives::types::{AccountId, BlockHeightDelta, EpochId, Gas, NumShards, ShardId};
+use near_primitives::types::{AccountId, BlockHeightDelta, Gas, ShardId};
 use near_primitives::version::{ProtocolFeature, PROTOCOL_VERSION};
-use near_store::adapter::StoreAdapter;
-use near_store::db::refcount::decode_value_with_rc;
-use near_store::{get, DBCol, ShardUId, Trie};
-use std::collections::{BTreeMap, HashMap, HashSet};
+use near_store::ShardUId;
+use std::collections::{BTreeMap, HashMap};
 use std::sync::Arc;

 use crate::test_loop::builder::TestLoopBuilder;
 use crate::test_loop::env::{TestData, TestLoopEnv};
+use crate::test_loop::utils::receipts::{
+    check_receipts_presence_after_resharding_block, check_receipts_presence_at_resharding_block,
+    ReceiptKind,
+};
+use crate::test_loop::utils::sharding::{
+    next_block_has_new_shard_layout, print_and_assert_shard_accounts,
+};
 use crate::test_loop::utils::transactions::{
-    get_shared_block_hash, get_smallest_height_head, run_tx, submit_tx,
+    get_shared_block_hash, get_smallest_height_head, run_tx, store_and_submit_tx,
+};
+use crate::test_loop::utils::trie_sanity::{
+    check_state_shard_uid_mapping_after_resharding, TrieSanityCheck,
 };
-use crate::test_loop::utils::{ONE_NEAR, TGAS};
+use crate::test_loop::utils::{LoopActionFn, 
ONE_NEAR, TGAS}; use assert_matches::assert_matches; use near_client::client_actor::ClientActorInner; use near_crypto::Signer; -use near_epoch_manager::EpochManagerAdapter; use near_parameters::{vm, RuntimeConfig, RuntimeConfigStore}; -use near_primitives::receipt::{ - BufferedReceiptIndices, DelayedReceiptIndices, PromiseYieldIndices, -}; -use near_primitives::state::FlatStateValue; use near_primitives::test_utils::create_user_test_signer; use near_primitives::transaction::SignedTransaction; -use near_primitives::trie_key::TrieKey; use near_primitives::views::FinalExecutionStatus; -use near_store::flat::FlatStorageStatus; use std::cell::Cell; use std::u64; -fn client_tracking_shard(client: &Client, shard_id: ShardId, parent_hash: &CryptoHash) -> bool { - let signer = client.validator_signer.get(); - let account_id = signer.as_ref().map(|s| s.validator_id()); - client.shard_tracker.care_about_shard(account_id, parent_hash, shard_id, true) -} - -fn get_client_tracking_shard<'a>( - clients: &'a [&Client], - tip: &Tip, - shard_id: ShardId, -) -> &'a Client { - for client in clients { - if client_tracking_shard(client, shard_id, &tip.prev_block_hash) { - return client; - } - } - panic!( - "get_client_tracking_shard() could not find client tracking shard {} at {} #{}", - shard_id, &tip.last_block_hash, tip.height - ); -} - -fn print_and_assert_shard_accounts(clients: &[&Client], tip: &Tip) { - let epoch_config = clients[0].epoch_manager.get_epoch_config(&tip.epoch_id).unwrap(); - for shard_uid in epoch_config.shard_layout.shard_uids() { - let client = get_client_tracking_shard(clients, tip, shard_uid.shard_id()); - let chunk_extra = client.chain.get_chunk_extra(&tip.prev_block_hash, &shard_uid).unwrap(); - let trie = client - .runtime_adapter - .get_trie_for_shard( - shard_uid.shard_id(), - &tip.prev_block_hash, - *chunk_extra.state_root(), - false, - ) - .unwrap(); - let mut shard_accounts = vec![]; - for item in trie.lock_for_iter().iter().unwrap() { - let (key, value) = item.unwrap(); - let state_record = StateRecord::from_raw_key_value(key, value); - if let Some(StateRecord::Account { account_id, .. }) = state_record { - shard_accounts.push(account_id.to_string()); - } - } - println!("accounts for shard {}: {:?}", shard_uid, shard_accounts); - assert!(!shard_accounts.is_empty()); - } -} - -/// Asserts that all parent shard State is accessible via parent and children shards. -fn check_state_shard_uid_mapping_after_resharding(client: &Client, parent_shard_uid: ShardUId) { - let tip = client.chain.head().unwrap(); - let epoch_id = tip.epoch_id; - let epoch_config = client.epoch_manager.get_epoch_config(&epoch_id).unwrap(); - let children_shard_uids = - epoch_config.shard_layout.get_children_shards_uids(parent_shard_uid.shard_id()).unwrap(); - assert_eq!(children_shard_uids.len(), 2); - - let store = client.chain.chain_store.store().trie_store(); - for kv in store.store().iter_raw_bytes(DBCol::State) { - let (key, value) = kv.unwrap(); - let shard_uid = ShardUId::try_from_slice(&key[0..8]).unwrap(); - // Just after resharding, no State data must be keyed using children ShardUIds. - assert!(!children_shard_uids.contains(&shard_uid)); - if shard_uid != parent_shard_uid { - continue; - } - let node_hash = CryptoHash::try_from_slice(&key[8..]).unwrap(); - let (value, _) = decode_value_with_rc(&value); - let parent_value = store.get(parent_shard_uid, &node_hash); - // Parent shard data must still be accessible using parent ShardUId. 
-        assert_eq!(&parent_value.unwrap()[..], value.unwrap());
-        // All parent shard data is available via both children shards.
-        for child_shard_uid in &children_shard_uids {
-            let child_value = store.get(*child_shard_uid, &node_hash);
-            assert_eq!(&child_value.unwrap()[..], value.unwrap());
-        }
-    }
-}
-
-/// Signature of functions callable from inside the inner loop of the resharding suite of tests.
-type LoopActionFn =
-    Box<dyn Fn(&[TestData], &mut TestLoopData, TestLoopDataHandle<ClientActorInner>)>;
-
 #[derive(Default)]
 struct TestReshardingParameters {
     chunk_ranges_to_drop: HashMap<ShardUId, std::ops::Range<i64>>,
@@ -327,146 +238,6 @@ fn fork_before_resharding_block(double_signing: bool) -> LoopActionFn {
     )
 }

-enum ReceiptKind {
-    Delayed,
-    Buffered,
-    PromiseYield,
-}
-
-/// Checks that the shards containing `accounts` have a non-empty set of receipts
-/// of type `kind` at the resharding block.
-fn check_receipts_presence_at_resharding_block(
-    accounts: Vec<AccountId>,
-    kind: ReceiptKind,
-) -> LoopActionFn {
-    Box::new(
-        move |_: &[TestData],
-              test_loop_data: &mut TestLoopData,
-              client_handle: TestLoopDataHandle<ClientActorInner>| {
-            let client_actor = test_loop_data.get_mut(&client_handle);
-            let tip = client_actor.client.chain.head().unwrap();
-
-            if !next_block_has_new_shard_layout(client_actor.client.epoch_manager.as_ref(), &tip) {
-                return;
-            }
-
-            accounts.iter().for_each(|account| {
-                check_receipts_at_block(client_actor, &account, &kind, tip.clone())
-            });
-        },
-    )
-}
-
-/// Checks that the shards containing `accounts` have a non-empty set of receipts
-/// of type `kind` at the block after the resharding block.
-fn check_receipts_presence_after_resharding_block(
-    accounts: Vec<AccountId>,
-    kind: ReceiptKind,
-) -> LoopActionFn {
-    Box::new(
-        move |_: &[TestData],
-              test_loop_data: &mut TestLoopData,
-              client_handle: TestLoopDataHandle<ClientActorInner>| {
-            let client_actor = test_loop_data.get_mut(&client_handle);
-            let tip = client_actor.client.chain.head().unwrap();
-
-            if !this_block_has_new_shard_layout(client_actor.client.epoch_manager.as_ref(), &tip) {
-                return;
-            }
-
-            accounts.iter().for_each(|account| {
-                check_receipts_at_block(client_actor, &account, &kind, tip.clone())
-            });
-        },
-    )
-}
-
-fn check_receipts_at_block(
-    client_actor: &mut ClientActorInner,
-    account: &AccountId,
-    kind: &ReceiptKind,
-    tip: Tip,
-) {
-    let epoch_manager = &client_actor.client.epoch_manager;
-    let shard_layout = epoch_manager.get_shard_layout(&tip.epoch_id).unwrap();
-    let shard_id = epoch_manager.account_id_to_shard_id(&account, &tip.epoch_id).unwrap();
-    let shard_uid = &ShardUId::from_shard_id_and_layout(shard_id, &shard_layout);
-    let congestion_info = &client_actor
-        .client
-        .chain
-        .chain_store()
-        .get_chunk_extra(&tip.last_block_hash, shard_uid)
-        .unwrap()
-        .congestion_info()
-        .unwrap();
-
-    let num_shards = shard_layout.shard_ids().count();
-    let has_delayed = congestion_info.delayed_receipts_gas() != 0;
-    let has_buffered = congestion_info.buffered_receipts_gas() != 0;
-    tracing::info!(target: "test", height=tip.height, num_shards, ?shard_id, has_delayed, has_buffered, "checking receipts");
-
-    match kind {
-        ReceiptKind::Delayed => {
-            assert!(has_delayed);
-            check_delayed_receipts_exist_in_memtrie(
-                &client_actor.client,
-                &shard_uid,
-                &tip.prev_block_hash,
-            );
-        }
-        ReceiptKind::Buffered => {
-            assert!(has_buffered);
-            check_buffered_receipts_exist_in_memtrie(
-                &client_actor.client,
-                &shard_uid,
-                &tip.prev_block_hash,
-            );
-        }
-        ReceiptKind::PromiseYield => check_promise_yield_receipts_exist_in_memtrie(
-            &client_actor.client,
-            &shard_uid,
-            &tip.prev_block_hash,
-        ),
-    }
-}
-
-/// Asserts that a non-zero amount of delayed receipts exist in MemTrie for the given shard.
-fn check_delayed_receipts_exist_in_memtrie(
-    client: &Client,
-    shard_uid: &ShardUId,
-    prev_block_hash: &CryptoHash,
-) {
-    let memtrie = get_memtrie_for_shard(client, shard_uid, prev_block_hash);
-    let indices: DelayedReceiptIndices =
-        get(&memtrie, &TrieKey::DelayedReceiptIndices).unwrap().unwrap();
-    assert_ne!(indices.len(), 0);
-}
-
-/// Asserts that a non-zero amount of buffered receipts exist in MemTrie for the given shard.
-fn check_buffered_receipts_exist_in_memtrie(
-    client: &Client,
-    shard_uid: &ShardUId,
-    prev_block_hash: &CryptoHash,
-) {
-    let memtrie = get_memtrie_for_shard(client, shard_uid, prev_block_hash);
-    let indices: BufferedReceiptIndices =
-        get(&memtrie, &TrieKey::BufferedReceiptIndices).unwrap().unwrap();
-    // There should be at least one buffered receipt going to some other shard. It's not very precise but good enough.
-    assert_ne!(indices.shard_buffers.values().fold(0, |acc, buffer| acc + buffer.len()), 0);
-}
-
-/// Asserts that a non-zero amount of promise yield receipts exist in MemTrie for the given shard.
-fn check_promise_yield_receipts_exist_in_memtrie(
-    client: &Client,
-    shard_uid: &ShardUId,
-    prev_block_hash: &CryptoHash,
-) {
-    let memtrie = get_memtrie_for_shard(client, shard_uid, prev_block_hash);
-    let indices: PromiseYieldIndices =
-        get(&memtrie, &TrieKey::PromiseYieldIndices).unwrap().unwrap();
-    assert_ne!(indices.len(), 0);
-}
-
 /// Returns a loop action that invokes a costly method from a contract
 /// `CALLS_PER_BLOCK_HEIGHT` times per block height.
 ///
@@ -689,403 +460,6 @@ fn call_promise_yield(
     )
 }

-/// Stores a transaction hash into `txs` and submits the transaction.
-fn store_and_submit_tx(
-    node_datas: &[TestData],
-    rpc_id: &AccountId,
-    txs: &Cell<Vec<(CryptoHash, u64)>>,
-    signer_id: &AccountId,
-    receiver_id: &AccountId,
-    height: u64,
-    tx: SignedTransaction,
-) {
-    let mut txs_vec = txs.take();
-    tracing::debug!(target: "test", height, tx_hash=?tx.get_hash(), ?signer_id, ?receiver_id, "submitting transaction");
-    txs_vec.push((tx.get_hash(), height));
-    txs.set(txs_vec);
-    submit_tx(&node_datas, &rpc_id, tx);
-}
-
-// We want to understand if the most recent block is a resharding block. To do
-// this, check if the latest block is an epoch start and compare the two epochs'
-// shard layouts.
-fn next_block_has_new_shard_layout(epoch_manager: &dyn EpochManagerAdapter, tip: &Tip) -> bool {
-    if !epoch_manager.is_next_block_epoch_start(&tip.last_block_hash).unwrap() {
-        return false;
-    }
-
-    let this_epoch_id = tip.epoch_id;
-    let next_epoch_id = epoch_manager.get_next_epoch_id(&tip.last_block_hash).unwrap();
-
-    let this_shard_layout = epoch_manager.get_shard_layout(&this_epoch_id).unwrap();
-    let next_shard_layout = epoch_manager.get_shard_layout(&next_epoch_id).unwrap();
-
-    this_shard_layout != next_shard_layout
-}
-
-// We want to understand if the most recent block is the first block with the
-// new shard layout. This is also the block immediately after the resharding
-// block. To do this, check if the latest block is an epoch start and compare the
-// two epochs' shard layouts.
-fn this_block_has_new_shard_layout(epoch_manager: &dyn EpochManagerAdapter, tip: &Tip) -> bool {
-    if !epoch_manager.is_next_block_epoch_start(&tip.prev_block_hash).unwrap() {
-        return false;
-    }
-
-    let prev_epoch_id = epoch_manager.get_epoch_id(&tip.prev_block_hash).unwrap();
-    let this_epoch_id = epoch_manager.get_epoch_id(&tip.last_block_hash).unwrap();
-
-    let prev_shard_layout = epoch_manager.get_shard_layout(&prev_epoch_id).unwrap();
-    let this_shard_layout = epoch_manager.get_shard_layout(&this_epoch_id).unwrap();
-
-    this_shard_layout != prev_shard_layout
-}
-
-fn get_memtrie_for_shard(
-    client: &Client,
-    shard_uid: &ShardUId,
-    prev_block_hash: &CryptoHash,
-) -> Trie {
-    let state_root =
-        *client.chain.get_chunk_extra(prev_block_hash, shard_uid).unwrap().state_root();
-
-    // Here memtries will be used as long as client has memtries enabled.
-    let memtrie = client
-        .runtime_adapter
-        .get_trie_for_shard(shard_uid.shard_id(), prev_block_hash, state_root, false)
-        .unwrap();
-    assert!(memtrie.has_memtries());
-    memtrie
-}
-
-fn assert_state_equal(
-    values1: &HashSet<(Vec<u8>, Vec<u8>)>,
-    values2: &HashSet<(Vec<u8>, Vec<u8>)>,
-    shard_uid: ShardUId,
-    cmp_msg: &str,
-) {
-    let diff = values1.symmetric_difference(values2);
-    let mut has_diff = false;
-    for (key, value) in diff {
-        has_diff = true;
-        tracing::error!(target: "test", ?shard_uid, key=?key, ?value, "Difference in state between {}!", cmp_msg);
-    }
-    assert!(!has_diff, "{} state mismatch!", cmp_msg);
-}
-
-fn shard_was_split(shard_layout: &ShardLayout, shard_id: ShardId) -> bool {
-    let Ok(parent) = shard_layout.get_parent_shard_id(shard_id) else {
-        return false;
-    };
-    parent != shard_id
-}
-
-/// Asserts that for each child shard, MemTrie, FlatState and DiskTrie all
-/// contain the same key-value pairs. If `load_mem_tries_for_tracked_shards` is
-/// false, we only enforce memtries for shards pending resharding in the old
-/// layout and the shards that were split in the new shard layout.
-///
-/// Returns the ShardUIds that this client tracks and has sane memtries and flat
-/// storage for.
-///
-/// The new num shards argument is a clumsy way to check if the head is before
-/// or after resharding.
-fn assert_state_sanity(
-    client: &Client,
-    final_head: &Tip,
-    load_mem_tries_for_tracked_shards: bool,
-    new_num_shards: NumShards,
-) -> Vec<ShardUId> {
-    let shard_layout = client.epoch_manager.get_shard_layout(&final_head.epoch_id).unwrap();
-    let is_resharded = shard_layout.num_shards() == new_num_shards;
-    let mut checked_shards = Vec::new();
-
-    let protocol_version =
-        client.epoch_manager.get_epoch_protocol_version(&final_head.epoch_id).unwrap();
-    let shards_pending_resharding = client
-        .epoch_manager
-        .get_shard_uids_pending_resharding(protocol_version, PROTOCOL_VERSION)
-        .unwrap();
-
-    for shard_uid in shard_layout.shard_uids() {
-        if !should_assert_state_sanity(
-            load_mem_tries_for_tracked_shards,
-            is_resharded,
-            &shards_pending_resharding,
-            &shard_layout,
-            &shard_uid,
-        ) {
-            continue;
-        }
-
-        if !client_tracking_shard(client, shard_uid.shard_id(), &final_head.prev_block_hash) {
-            continue;
-        }
-
-        let memtrie = get_memtrie_for_shard(client, &shard_uid, &final_head.prev_block_hash);
-        let memtrie_state =
-            memtrie.lock_for_iter().iter().unwrap().collect::<Result<HashSet<_>, _>>().unwrap();
-
-        let state_root = *client
-            .chain
-            .get_chunk_extra(&final_head.prev_block_hash, &shard_uid)
-            .unwrap()
-            .state_root();
-
-        // To get a view on disk tries we can leverage the fact that get_view_trie_for_shard() never
-        // uses memtries.
-        let trie = client
-            .runtime_adapter
-            .get_view_trie_for_shard(shard_uid.shard_id(), &final_head.prev_block_hash, state_root)
-            .unwrap();
-        assert!(!trie.has_memtries());
-        let trie_state =
-            trie.lock_for_iter().iter().unwrap().collect::<Result<HashSet<_>, _>>().unwrap();
-        assert_state_equal(&memtrie_state, &trie_state, shard_uid, "memtrie and trie");
-
-        let flat_storage_manager = client.chain.runtime_adapter.get_flat_storage_manager();
-        // FlatStorageChunkView::iter_range() used below to retrieve all key-value pairs in Flat
-        // Storage only looks at the data committed into the DB. For this reason comparing Flat
-        // Storage and Memtries makes sense only if we can retrieve a view at the same height from
-        // both.
-        if let FlatStorageStatus::Ready(status) =
-            flat_storage_manager.get_flat_storage_status(shard_uid)
-        {
-            if status.flat_head.hash != final_head.prev_block_hash {
-                tracing::warn!(target: "test", "skipping flat storage - memtrie state check");
-                continue;
-            } else {
-                tracing::debug!(target: "test", "checking flat storage - memtrie state");
-            }
-        } else {
-            continue;
-        };
-        let Some(flat_store_chunk_view) =
-            flat_storage_manager.chunk_view(shard_uid, final_head.last_block_hash)
-        else {
-            continue;
-        };
-        let flat_store_state = flat_store_chunk_view
-            .iter_range(None, None)
-            .map_ok(|(key, value)| {
-                let value = match value {
-                    FlatStateValue::Ref(value) => client
-                        .chain
-                        .chain_store()
-                        .store()
-                        .trie_store()
-                        .get(shard_uid, &value.hash)
-                        .unwrap()
-                        .to_vec(),
-                    FlatStateValue::Inlined(data) => data,
-                };
-                (key, value)
-            })
-            .collect::<Result<HashSet<_>, _>>()
-            .unwrap();
-
-        assert_state_equal(&memtrie_state, &flat_store_state, shard_uid, "memtrie and flat store");
-        checked_shards.push(shard_uid);
-    }
-    checked_shards
-}
-
-fn should_assert_state_sanity(
-    load_mem_tries_for_tracked_shards: bool,
-    is_resharded: bool,
-    shards_pending_resharding: &HashSet<ShardUId>,
-    shard_layout: &ShardLayout,
-    shard_uid: &ShardUId,
-) -> bool {
-    // Always assert if the tracked shards are loaded into memory.
-    if load_mem_tries_for_tracked_shards {
-        return true;
-    }
-
-    // In the old layout do not enforce except for shards pending resharding.
-    if !is_resharded && !shards_pending_resharding.contains(&shard_uid) {
-        return false;
-    }
-
-    // In the new layout do not enforce for shards that were not split.
-    if is_resharded && !shard_was_split(shard_layout, shard_uid.shard_id()) {
-        return false;
-    }
-
-    true
-}
-
-// For each epoch, keep a map from AccountId to a map with keys equal to
-// the set of shards that account tracks in that epoch, and bool values indicating
-// whether the equality of flat storage and memtries has been checked for that shard
-type EpochTrieCheck = HashMap<AccountId, HashMap<ShardUId, bool>>;
-
-/// Keeps track of the needed trie comparisons for each epoch. After we successfully call
-/// assert_state_sanity() for an account ID, we mark those shards as checked for that epoch,
-/// and then at the end of the test we check whether all expected shards for each account
-/// were checked at least once in that epoch. We do this because assert_state_sanity() isn't
-/// always able to perform the check if child shard flat storages are still being created, but
-/// we want to make sure that it's always eventually checked by the end of the epoch.
-struct TrieSanityCheck {
-    accounts: Vec<AccountId>,
-    load_mem_tries_for_tracked_shards: bool,
-    checks: HashMap<EpochId, EpochTrieCheck>,
-}
-
-impl TrieSanityCheck {
-    fn new(clients: &[&Client], load_mem_tries_for_tracked_shards: bool) -> Self {
-        let accounts = clients
-            .iter()
-            .filter_map(|c| {
-                let signer = c.validator_signer.get();
-                signer.map(|s| s.validator_id().clone())
-            })
-            .collect();
-        Self { accounts, load_mem_tries_for_tracked_shards, checks: HashMap::new() }
-    }
-
-    // If it's not already stored, initialize it with the expected ShardUIds for each account
-    fn get_epoch_check(
-        &mut self,
-        client: &Client,
-        tip: &Tip,
-        new_num_shards: NumShards,
-    ) -> &mut EpochTrieCheck {
-        let protocol_version =
-            client.epoch_manager.get_epoch_protocol_version(&tip.epoch_id).unwrap();
-        let shards_pending_resharding = client
-            .epoch_manager
-            .get_shard_uids_pending_resharding(protocol_version, PROTOCOL_VERSION)
-            .unwrap();
-        let shard_layout = client.epoch_manager.get_shard_layout(&tip.epoch_id).unwrap();
-        let is_resharded = shard_layout.num_shards() == new_num_shards;
-
-        if self.checks.contains_key(&tip.epoch_id) {
-            return self.checks.get_mut(&tip.epoch_id).unwrap();
-        }
-
-        let mut check = HashMap::new();
-        for account_id in self.accounts.iter() {
-            let check_shard_uids = self.get_epoch_check_for_account(
-                client,
-                tip,
-                is_resharded,
-                &shards_pending_resharding,
-                &shard_layout,
-                account_id,
-            );
-            check.insert(account_id.clone(), check_shard_uids);
-        }
-
-        self.checks.insert(tip.epoch_id, check);
-        self.checks.get_mut(&tip.epoch_id).unwrap()
-    }
-
-    // Returns the expected shard uids for the given account.
-    fn get_epoch_check_for_account(
-        &self,
-        client: &Client,
-        tip: &Tip,
-        is_resharded: bool,
-        shards_pending_resharding: &HashSet<ShardUId>,
-        shard_layout: &ShardLayout,
-        account_id: &AccountId,
-    ) -> HashMap<ShardUId, bool> {
-        let mut check_shard_uids = HashMap::new();
-        for shard_uid in shard_layout.shard_uids() {
-            if !should_assert_state_sanity(
-                self.load_mem_tries_for_tracked_shards,
-                is_resharded,
-                shards_pending_resharding,
-                shard_layout,
-                &shard_uid,
-            ) {
-                continue;
-            }
-
-            let cares = client.shard_tracker.care_about_shard(
-                Some(account_id),
-                &tip.prev_block_hash,
-                shard_uid.shard_id(),
-                false,
-            );
-            if !cares {
-                continue;
-            }
-            check_shard_uids.insert(shard_uid, false);
-        }
-        check_shard_uids
-    }
-
-    // Check trie sanity and keep track of which shards were successfully fully checked
-    fn assert_state_sanity(&mut self, clients: &[&Client], new_num_shards: NumShards) {
-        for client in clients {
-            let signer = client.validator_signer.get();
-            let Some(account_id) = signer.as_ref().map(|s| s.validator_id()) else {
-                // For now this is never relevant, since all of them have account IDs, but
-                // if this changes in the future, here we'll just skip those.
-                continue;
-            };
-            let head = client.chain.head().unwrap();
-            if head.epoch_id == EpochId::default() {
-                continue;
-            }
-            let final_head = client.chain.final_head().unwrap();
-            // At the end of an epoch, we unload memtries for shards we'll no longer track. Also,
-            // the key/value equality comparison in assert_state_equal() is only guaranteed for
-            // final blocks. So these two together mean that we should only check this when the head
-            // and final head are in the same epoch.
- if head.epoch_id != final_head.epoch_id { - continue; - } - let checked_shards = assert_state_sanity( - client, - &final_head, - self.load_mem_tries_for_tracked_shards, - new_num_shards, - ); - let check = self.get_epoch_check(client, &head, new_num_shards); - let check = check.get_mut(account_id).unwrap(); - for shard_uid in checked_shards { - check.insert(shard_uid, true); - } - } - } - - /// Look through all the epochs before the current one (because the current one will be early into the epoch, - /// and we won't have checked it yet) and make sure that for all accounts, all expected shards were checked at least once - fn check_epochs(&self, client: &Client) { - let tip = client.chain.head().unwrap(); - let mut block_info = client.epoch_manager.get_block_info(&tip.last_block_hash).unwrap(); - - loop { - let epoch_id = client - .epoch_manager - .get_prev_epoch_id_from_prev_block(block_info.prev_hash()) - .unwrap(); - if epoch_id == EpochId::default() { - break; - } - let check = self.checks.get(&epoch_id).unwrap_or_else(|| { - panic!("No trie comparison checks made for epoch {}", &epoch_id.0) - }); - for (account_id, checked_shards) in check.iter() { - for (shard_uid, checked) in checked_shards.iter() { - assert!( - checked, - "No trie comparison checks made for account {} epoch {} shard {}", - account_id, &epoch_id.0, shard_uid - ); - } - } - - block_info = - client.epoch_manager.get_block_info(block_info.epoch_first_block()).unwrap(); - block_info = client.epoch_manager.get_block_info(block_info.prev_hash()).unwrap(); - } - } -} - fn get_base_shard_layout(version: u64) -> ShardLayout { let boundary_accounts = vec!["account1".parse().unwrap(), "account3".parse().unwrap()]; match version { diff --git a/integration-tests/src/test_loop/utils/mod.rs b/integration-tests/src/test_loop/utils/mod.rs index fb7d6a2d6bd..492d644b3e7 100644 --- a/integration-tests/src/test_loop/utils/mod.rs +++ b/integration-tests/src/test_loop/utils/mod.rs @@ -1,10 +1,15 @@ -use super::env::TestLoopEnv; +use super::env::{TestData, TestLoopEnv}; +use near_async::test_loop::data::{TestLoopData, TestLoopDataHandle}; +use near_client::client_actor::ClientActorInner; -pub mod contract_distribution; -pub mod network; -pub mod setups; -pub mod transactions; -pub mod validators; +pub(crate) mod contract_distribution; +pub(crate) mod network; +pub(crate) mod receipts; +pub(crate) mod setups; +pub(crate) mod sharding; +pub(crate) mod transactions; +pub(crate) mod trie_sanity; +pub(crate) mod validators; pub(crate) const ONE_NEAR: u128 = 1_000_000_000_000_000_000_000_000; pub(crate) const TGAS: u64 = 1_000_000_000_000; @@ -15,3 +20,7 @@ pub(crate) fn get_head_height(env: &mut TestLoopEnv) -> u64 { let client = &env.test_loop.data.get(&client_handle).client; client.chain.head().unwrap().height } + +/// Signature of functions callable from inside the inner loop of a test loop test. 
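+///
+/// A minimal conforming action might look like this (an illustrative sketch
+/// only, not part of this PR's test suite):
+/// ```ignore
+/// let log_height: LoopActionFn = Box::new(
+///     |_: &[TestData],
+///      test_loop_data: &mut TestLoopData,
+///      client_handle: TestLoopDataHandle<ClientActorInner>| {
+///         let client_actor = test_loop_data.get_mut(&client_handle);
+///         let tip = client_actor.client.chain.head().unwrap();
+///         tracing::debug!(target: "test", height = tip.height, "loop action tick");
+///     },
+/// );
+/// ```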
+pub(crate) type LoopActionFn =
+    Box<dyn Fn(&[TestData], &mut TestLoopData, TestLoopDataHandle<ClientActorInner>)>;
diff --git a/integration-tests/src/test_loop/utils/receipts.rs b/integration-tests/src/test_loop/utils/receipts.rs
new file mode 100644
index 00000000000..c49c921036b
--- /dev/null
+++ b/integration-tests/src/test_loop/utils/receipts.rs
@@ -0,0 +1,157 @@
+use super::sharding::{next_block_has_new_shard_layout, this_block_has_new_shard_layout};
+use super::LoopActionFn;
+use crate::test_loop::env::TestData;
+use crate::test_loop::utils::sharding::get_memtrie_for_shard;
+use near_async::test_loop::data::{TestLoopData, TestLoopDataHandle};
+use near_chain::types::Tip;
+use near_chain::ChainStoreAccess;
+use near_client::client_actor::ClientActorInner;
+use near_client::Client;
+use near_primitives::hash::CryptoHash;
+use near_primitives::receipt::{
+    BufferedReceiptIndices, DelayedReceiptIndices, PromiseYieldIndices,
+};
+use near_primitives::trie_key::TrieKey;
+use near_primitives::types::AccountId;
+use near_store::{get, ShardUId};
+
+pub enum ReceiptKind {
+    Delayed,
+    Buffered,
+    PromiseYield,
+}
+
+/// Checks that the shards containing `accounts` have a non-empty set of receipts
+/// of type `kind` at the resharding block.
+pub fn check_receipts_presence_at_resharding_block(
+    accounts: Vec<AccountId>,
+    kind: ReceiptKind,
+) -> LoopActionFn {
+    Box::new(
+        move |_: &[TestData],
+              test_loop_data: &mut TestLoopData,
+              client_handle: TestLoopDataHandle<ClientActorInner>| {
+            let client_actor = test_loop_data.get_mut(&client_handle);
+            let tip = client_actor.client.chain.head().unwrap();
+
+            if !next_block_has_new_shard_layout(client_actor.client.epoch_manager.as_ref(), &tip) {
+                return;
+            }
+
+            accounts.iter().for_each(|account| {
+                check_receipts_at_block(client_actor, &account, &kind, tip.clone())
+            });
+        },
+    )
+}
+
+/// Checks that the shards containing `accounts` have a non-empty set of receipts
+/// of type `kind` at the block after the resharding block.
+pub fn check_receipts_presence_after_resharding_block(
+    accounts: Vec<AccountId>,
+    kind: ReceiptKind,
+) -> LoopActionFn {
+    Box::new(
+        move |_: &[TestData],
+              test_loop_data: &mut TestLoopData,
+              client_handle: TestLoopDataHandle<ClientActorInner>| {
+            let client_actor = test_loop_data.get_mut(&client_handle);
+            let tip = client_actor.client.chain.head().unwrap();
+
+            if !this_block_has_new_shard_layout(client_actor.client.epoch_manager.as_ref(), &tip) {
+                return;
+            }
+
+            accounts.iter().for_each(|account| {
+                check_receipts_at_block(client_actor, &account, &kind, tip.clone())
+            });
+        },
+    )
+}
+
+/// Asserts the presence of any receipt of type `kind` at the provided chain `tip`. 
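+///
+/// Concretely, the check below first consults the congestion info stored in the
+/// chunk extra of the account's shard, and then verifies the corresponding
+/// receipt indices directly in the shard's MemTrie.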
+pub fn check_receipts_at_block(
+    client_actor: &mut ClientActorInner,
+    account: &AccountId,
+    kind: &ReceiptKind,
+    tip: Tip,
+) {
+    let epoch_manager = &client_actor.client.epoch_manager;
+    let shard_layout = epoch_manager.get_shard_layout(&tip.epoch_id).unwrap();
+    let shard_id = epoch_manager.account_id_to_shard_id(&account, &tip.epoch_id).unwrap();
+    let shard_uid = &ShardUId::from_shard_id_and_layout(shard_id, &shard_layout);
+    let congestion_info = &client_actor
+        .client
+        .chain
+        .chain_store()
+        .get_chunk_extra(&tip.last_block_hash, shard_uid)
+        .unwrap()
+        .congestion_info()
+        .unwrap();
+
+    let num_shards = shard_layout.shard_ids().count();
+    let has_delayed = congestion_info.delayed_receipts_gas() != 0;
+    let has_buffered = congestion_info.buffered_receipts_gas() != 0;
+    tracing::info!(target: "test", height=tip.height, num_shards, ?shard_id, has_delayed, has_buffered, "checking receipts");
+
+    match kind {
+        ReceiptKind::Delayed => {
+            assert!(has_delayed);
+            check_delayed_receipts_exist_in_memtrie(
+                &client_actor.client,
+                &shard_uid,
+                &tip.prev_block_hash,
+            );
+        }
+        ReceiptKind::Buffered => {
+            assert!(has_buffered);
+            check_buffered_receipts_exist_in_memtrie(
+                &client_actor.client,
+                &shard_uid,
+                &tip.prev_block_hash,
+            );
+        }
+        ReceiptKind::PromiseYield => check_promise_yield_receipts_exist_in_memtrie(
+            &client_actor.client,
+            &shard_uid,
+            &tip.prev_block_hash,
+        ),
+    }
+}
+
+/// Asserts that a non-zero amount of delayed receipts exist in MemTrie for the given shard.
+fn check_delayed_receipts_exist_in_memtrie(
+    client: &Client,
+    shard_uid: &ShardUId,
+    prev_block_hash: &CryptoHash,
+) {
+    let memtrie = get_memtrie_for_shard(client, shard_uid, prev_block_hash);
+    let indices: DelayedReceiptIndices =
+        get(&memtrie, &TrieKey::DelayedReceiptIndices).unwrap().unwrap();
+    assert_ne!(indices.len(), 0);
+}
+
+/// Asserts that a non-zero amount of buffered receipts exist in MemTrie for the given shard.
+fn check_buffered_receipts_exist_in_memtrie(
+    client: &Client,
+    shard_uid: &ShardUId,
+    prev_block_hash: &CryptoHash,
+) {
+    let memtrie = get_memtrie_for_shard(client, shard_uid, prev_block_hash);
+    let indices: BufferedReceiptIndices =
+        get(&memtrie, &TrieKey::BufferedReceiptIndices).unwrap().unwrap();
+    // There should be at least one buffered receipt going to some other shard. It's not very precise but good enough.
+    assert_ne!(indices.shard_buffers.values().fold(0, |acc, buffer| acc + buffer.len()), 0);
+}
+
+/// Asserts that a non-zero amount of promise yield receipts exist in MemTrie for the given shard. 
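+/// In other words, it checks that at least one promise-yield receipt is still
+/// queued in that shard at the given block.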
+fn check_promise_yield_receipts_exist_in_memtrie(
+    client: &Client,
+    shard_uid: &ShardUId,
+    prev_block_hash: &CryptoHash,
+) {
+    let memtrie = get_memtrie_for_shard(client, shard_uid, prev_block_hash);
+    let indices: PromiseYieldIndices =
+        get(&memtrie, &TrieKey::PromiseYieldIndices).unwrap().unwrap();
+    assert_ne!(indices.len(), 0);
+}
diff --git a/integration-tests/src/test_loop/utils/sharding.rs b/integration-tests/src/test_loop/utils/sharding.rs
new file mode 100644
index 00000000000..8f377f122b1
--- /dev/null
+++ b/integration-tests/src/test_loop/utils/sharding.rs
@@ -0,0 +1,120 @@
+use near_chain::types::Tip;
+use near_client::Client;
+use near_epoch_manager::EpochManagerAdapter;
+use near_primitives::hash::CryptoHash;
+use near_primitives::shard_layout::ShardLayout;
+use near_primitives::state_record::StateRecord;
+use near_primitives::types::ShardId;
+use near_store::{ShardUId, Trie};
+
+/// Returns `true` if `client` is tracking the shard with the given `shard_id`.
+pub fn client_tracking_shard(client: &Client, shard_id: ShardId, parent_hash: &CryptoHash) -> bool {
+    let signer = client.validator_signer.get();
+    let account_id = signer.as_ref().map(|s| s.validator_id());
+    client.shard_tracker.care_about_shard(account_id, parent_hash, shard_id, true)
+}
+
+// Finds the client that tracks the shard with `shard_id` among the list of `clients`.
+pub fn get_client_tracking_shard<'a>(
+    clients: &'a [&Client],
+    tip: &Tip,
+    shard_id: ShardId,
+) -> &'a Client {
+    for client in clients {
+        if client_tracking_shard(client, shard_id, &tip.prev_block_hash) {
+            return client;
+        }
+    }
+    panic!(
+        "get_client_tracking_shard() could not find client tracking shard {} at {} #{}",
+        shard_id, &tip.last_block_hash, tip.height
+    );
+}
+
+/// Prints the accounts inside all shards and asserts that no shard is empty.
+pub fn print_and_assert_shard_accounts(clients: &[&Client], tip: &Tip) {
+    let epoch_config = clients[0].epoch_manager.get_epoch_config(&tip.epoch_id).unwrap();
+    for shard_uid in epoch_config.shard_layout.shard_uids() {
+        let client = get_client_tracking_shard(clients, tip, shard_uid.shard_id());
+        let chunk_extra = client.chain.get_chunk_extra(&tip.prev_block_hash, &shard_uid).unwrap();
+        let trie = client
+            .runtime_adapter
+            .get_trie_for_shard(
+                shard_uid.shard_id(),
+                &tip.prev_block_hash,
+                *chunk_extra.state_root(),
+                false,
+            )
+            .unwrap();
+        let mut shard_accounts = vec![];
+        for item in trie.lock_for_iter().iter().unwrap() {
+            let (key, value) = item.unwrap();
+            let state_record = StateRecord::from_raw_key_value(key, value);
+            if let Some(StateRecord::Account { account_id, .. }) = state_record {
+                shard_accounts.push(account_id.to_string());
+            }
+        }
+        println!("accounts for shard {}: {:?}", shard_uid, shard_accounts);
+        assert!(!shard_accounts.is_empty());
+    }
+}
+
+/// Returns the MemTrie of a shard at a certain block hash.
+pub fn get_memtrie_for_shard(
+    client: &Client,
+    shard_uid: &ShardUId,
+    prev_block_hash: &CryptoHash,
+) -> Trie {
+    let state_root =
+        *client.chain.get_chunk_extra(prev_block_hash, shard_uid).unwrap().state_root();
+
+    // Here memtries will be used as long as client has memtries enabled.
+    let memtrie = client
+        .runtime_adapter
+        .get_trie_for_shard(shard_uid.shard_id(), prev_block_hash, state_root, false)
+        .unwrap();
+    assert!(memtrie.has_memtries());
+    memtrie
+}
+
+// We want to understand if the most recent block is the first block with the
+// new shard layout. This is also the block immediately after the resharding
+// block. 
To do this, check if the latest block is an epoch start and compare the
+// two epochs' shard layouts.
+pub fn this_block_has_new_shard_layout(epoch_manager: &dyn EpochManagerAdapter, tip: &Tip) -> bool {
+    if !epoch_manager.is_next_block_epoch_start(&tip.prev_block_hash).unwrap() {
+        return false;
+    }
+
+    let prev_epoch_id = epoch_manager.get_epoch_id(&tip.prev_block_hash).unwrap();
+    let this_epoch_id = epoch_manager.get_epoch_id(&tip.last_block_hash).unwrap();
+
+    let prev_shard_layout = epoch_manager.get_shard_layout(&prev_epoch_id).unwrap();
+    let this_shard_layout = epoch_manager.get_shard_layout(&this_epoch_id).unwrap();
+
+    this_shard_layout != prev_shard_layout
+}
+
+// We want to understand if the most recent block is a resharding block. To do
+// this, check if the latest block is an epoch start and compare the two epochs'
+// shard layouts.
+pub fn next_block_has_new_shard_layout(epoch_manager: &dyn EpochManagerAdapter, tip: &Tip) -> bool {
+    if !epoch_manager.is_next_block_epoch_start(&tip.last_block_hash).unwrap() {
+        return false;
+    }
+
+    let this_epoch_id = tip.epoch_id;
+    let next_epoch_id = epoch_manager.get_next_epoch_id(&tip.last_block_hash).unwrap();
+
+    let this_shard_layout = epoch_manager.get_shard_layout(&this_epoch_id).unwrap();
+    let next_shard_layout = epoch_manager.get_shard_layout(&next_epoch_id).unwrap();
+
+    this_shard_layout != next_shard_layout
+}
+
+pub fn shard_was_split(shard_layout: &ShardLayout, shard_id: ShardId) -> bool {
+    let Ok(parent) = shard_layout.get_parent_shard_id(shard_id) else {
+        return false;
+    };
+    parent != shard_id
+}
diff --git a/integration-tests/src/test_loop/utils/transactions.rs b/integration-tests/src/test_loop/utils/transactions.rs
index 4b9ef801680..90284d61550 100644
--- a/integration-tests/src/test_loop/utils/transactions.rs
+++ b/integration-tests/src/test_loop/utils/transactions.rs
@@ -26,6 +26,7 @@ use std::task::Poll;

 use super::{ONE_NEAR, TGAS};
 use near_async::futures::FutureSpawnerExt;
+use std::cell::Cell;

 /// See `execute_money_transfers`. Debug is implemented so .unwrap() can print
 /// the error.
@@ -658,3 +659,20 @@ enum TxProcessingResult {
     Congested(InvalidTxError),
     Invalid(InvalidTxError),
 }
+
+/// Stores the transaction hash into `txs` as a `(tx_hash, height)` pair and then
+/// submits the transaction. 
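+///
+/// Illustrative usage (a sketch; `nonce`, `amount` and the surrounding test
+/// setup are hypothetical):
+/// ```ignore
+/// let txs = Cell::new(Vec::new());
+/// let tx = SignedTransaction::send_money(
+///     nonce, signer_id.clone(), receiver_id.clone(), &signer, amount, block_hash,
+/// );
+/// store_and_submit_tx(&node_datas, &rpc_id, &txs, &signer_id, &receiver_id, height, tx);
+/// ```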
+pub fn store_and_submit_tx(
+    node_datas: &[TestData],
+    rpc_id: &AccountId,
+    txs: &Cell<Vec<(CryptoHash, u64)>>,
+    signer_id: &AccountId,
+    receiver_id: &AccountId,
+    height: u64,
+    tx: SignedTransaction,
+) {
+    let mut txs_vec = txs.take();
+    tracing::debug!(target: "test", height, tx_hash=?tx.get_hash(), ?signer_id, ?receiver_id, "submitting transaction");
+    txs_vec.push((tx.get_hash(), height));
+    txs.set(txs_vec);
+    submit_tx(&node_datas, &rpc_id, tx);
+}
diff --git a/integration-tests/src/test_loop/utils/trie_sanity.rs b/integration-tests/src/test_loop/utils/trie_sanity.rs
new file mode 100644
index 00000000000..c3d5147742d
--- /dev/null
+++ b/integration-tests/src/test_loop/utils/trie_sanity.rs
@@ -0,0 +1,369 @@
+use super::sharding::shard_was_split;
+use crate::test_loop::utils::sharding::{client_tracking_shard, get_memtrie_for_shard};
+use borsh::BorshDeserialize;
+use itertools::Itertools;
+use near_chain::types::Tip;
+use near_chain::ChainStoreAccess;
+use near_client::Client;
+use near_primitives::hash::CryptoHash;
+use near_primitives::shard_layout::ShardLayout;
+use near_primitives::state::FlatStateValue;
+use near_primitives::types::{AccountId, EpochId, NumShards};
+use near_primitives::version::PROTOCOL_VERSION;
+use near_store::adapter::StoreAdapter;
+use near_store::db::refcount::decode_value_with_rc;
+use near_store::flat::FlatStorageStatus;
+use near_store::{DBCol, ShardUId};
+use std::collections::{HashMap, HashSet};
+
+// For each epoch, keep a map from AccountId to a map with keys equal to
+// the set of shards that account tracks in that epoch, and bool values indicating
+// whether the equality of flat storage and memtries has been checked for that shard
+type EpochTrieCheck = HashMap<AccountId, HashMap<ShardUId, bool>>;
+
+/// Keeps track of the needed trie comparisons for each epoch. After we successfully call
+/// assert_state_sanity() for an account ID, we mark those shards as checked for that epoch,
+/// and then at the end of the test we check whether all expected shards for each account
+/// were checked at least once in that epoch. 
We do this because assert_state_sanity() isn't
+/// always able to perform the check if child shard flat storages are still being created, but
+/// we want to make sure that it's always eventually checked by the end of the epoch.
+pub struct TrieSanityCheck {
+    accounts: Vec<AccountId>,
+    load_mem_tries_for_tracked_shards: bool,
+    checks: HashMap<EpochId, EpochTrieCheck>,
+}
+
+impl TrieSanityCheck {
+    pub fn new(clients: &[&Client], load_mem_tries_for_tracked_shards: bool) -> Self {
+        let accounts = clients
+            .iter()
+            .filter_map(|c| {
+                let signer = c.validator_signer.get();
+                signer.map(|s| s.validator_id().clone())
+            })
+            .collect();
+        Self { accounts, load_mem_tries_for_tracked_shards, checks: HashMap::new() }
+    }
+
+    // If it's not already stored, initialize it with the expected ShardUIds for each account
+    fn get_epoch_check(
+        &mut self,
+        client: &Client,
+        tip: &Tip,
+        new_num_shards: NumShards,
+    ) -> &mut EpochTrieCheck {
+        let protocol_version =
+            client.epoch_manager.get_epoch_protocol_version(&tip.epoch_id).unwrap();
+        let shards_pending_resharding = client
+            .epoch_manager
+            .get_shard_uids_pending_resharding(protocol_version, PROTOCOL_VERSION)
+            .unwrap();
+        let shard_layout = client.epoch_manager.get_shard_layout(&tip.epoch_id).unwrap();
+        let is_resharded = shard_layout.num_shards() == new_num_shards;
+
+        if self.checks.contains_key(&tip.epoch_id) {
+            return self.checks.get_mut(&tip.epoch_id).unwrap();
+        }
+
+        let mut check = HashMap::new();
+        for account_id in self.accounts.iter() {
+            let check_shard_uids = self.get_epoch_check_for_account(
+                client,
+                tip,
+                is_resharded,
+                &shards_pending_resharding,
+                &shard_layout,
+                account_id,
+            );
+            check.insert(account_id.clone(), check_shard_uids);
+        }
+
+        self.checks.insert(tip.epoch_id, check);
+        self.checks.get_mut(&tip.epoch_id).unwrap()
+    }
+
+    // Returns the expected shard uids for the given account.
+    fn get_epoch_check_for_account(
+        &self,
+        client: &Client,
+        tip: &Tip,
+        is_resharded: bool,
+        shards_pending_resharding: &HashSet<ShardUId>,
+        shard_layout: &ShardLayout,
+        account_id: &AccountId,
+    ) -> HashMap<ShardUId, bool> {
+        let mut check_shard_uids = HashMap::new();
+        for shard_uid in shard_layout.shard_uids() {
+            if !should_assert_state_sanity(
+                self.load_mem_tries_for_tracked_shards,
+                is_resharded,
+                shards_pending_resharding,
+                shard_layout,
+                &shard_uid,
+            ) {
+                continue;
+            }
+
+            let cares = client.shard_tracker.care_about_shard(
+                Some(account_id),
+                &tip.prev_block_hash,
+                shard_uid.shard_id(),
+                false,
+            );
+            if !cares {
+                continue;
+            }
+            check_shard_uids.insert(shard_uid, false);
+        }
+        check_shard_uids
+    }
+
+    // Check trie sanity and keep track of which shards were successfully fully checked
+    pub fn assert_state_sanity(&mut self, clients: &[&Client], new_num_shards: NumShards) {
+        for client in clients {
+            let signer = client.validator_signer.get();
+            let Some(account_id) = signer.as_ref().map(|s| s.validator_id()) else {
+                // For now this is never relevant, since all of them have account IDs, but
+                // if this changes in the future, here we'll just skip those.
+                continue;
+            };
+            let head = client.chain.head().unwrap();
+            if head.epoch_id == EpochId::default() {
+                continue;
+            }
+            let final_head = client.chain.final_head().unwrap();
+            // At the end of an epoch, we unload memtries for shards we'll no longer track. Also,
+            // the key/value equality comparison in assert_state_equal() is only guaranteed for
+            // final blocks. So these two together mean that we should only check this when the head
+            // and final head are in the same epoch. 
+            if head.epoch_id != final_head.epoch_id {
+                continue;
+            }
+            let checked_shards = assert_state_sanity(
+                client,
+                &final_head,
+                self.load_mem_tries_for_tracked_shards,
+                new_num_shards,
+            );
+            let check = self.get_epoch_check(client, &head, new_num_shards);
+            let check = check.get_mut(account_id).unwrap();
+            for shard_uid in checked_shards {
+                check.insert(shard_uid, true);
+            }
+        }
+    }
+
+    /// Look through all the epochs before the current one (because the current one will be early into the epoch,
+    /// and we won't have checked it yet) and make sure that for all accounts, all expected shards were checked at least once
+    pub fn check_epochs(&self, client: &Client) {
+        let tip = client.chain.head().unwrap();
+        let mut block_info = client.epoch_manager.get_block_info(&tip.last_block_hash).unwrap();
+
+        loop {
+            let epoch_id = client
+                .epoch_manager
+                .get_prev_epoch_id_from_prev_block(block_info.prev_hash())
+                .unwrap();
+            if epoch_id == EpochId::default() {
+                break;
+            }
+            let check = self.checks.get(&epoch_id).unwrap_or_else(|| {
+                panic!("No trie comparison checks made for epoch {}", &epoch_id.0)
+            });
+            for (account_id, checked_shards) in check.iter() {
+                for (shard_uid, checked) in checked_shards.iter() {
+                    assert!(
+                        checked,
+                        "No trie comparison checks made for account {} epoch {} shard {}",
+                        account_id, &epoch_id.0, shard_uid
+                    );
+                }
+            }
+
+            block_info =
+                client.epoch_manager.get_block_info(block_info.epoch_first_block()).unwrap();
+            block_info = client.epoch_manager.get_block_info(block_info.prev_hash()).unwrap();
+        }
+    }
+}
+
+/// Asserts that for each child shard, MemTrie, FlatState and DiskTrie all
+/// contain the same key-value pairs. If `load_mem_tries_for_tracked_shards` is
+/// false, we only enforce memtries for shards pending resharding in the old
+/// layout and the shards that were split in the new shard layout.
+///
+/// Returns the ShardUIds that this client tracks and has sane memtries and flat
+/// storage for.
+///
+/// The new num shards argument is a clumsy way to check if the head is before
+/// or after resharding.
+fn assert_state_sanity(
+    client: &Client,
+    final_head: &Tip,
+    load_mem_tries_for_tracked_shards: bool,
+    new_num_shards: NumShards,
+) -> Vec<ShardUId> {
+    let shard_layout = client.epoch_manager.get_shard_layout(&final_head.epoch_id).unwrap();
+    let is_resharded = shard_layout.num_shards() == new_num_shards;
+    let mut checked_shards = Vec::new();
+
+    let protocol_version =
+        client.epoch_manager.get_epoch_protocol_version(&final_head.epoch_id).unwrap();
+    let shards_pending_resharding = client
+        .epoch_manager
+        .get_shard_uids_pending_resharding(protocol_version, PROTOCOL_VERSION)
+        .unwrap();
+
+    for shard_uid in shard_layout.shard_uids() {
+        if !should_assert_state_sanity(
+            load_mem_tries_for_tracked_shards,
+            is_resharded,
+            &shards_pending_resharding,
+            &shard_layout,
+            &shard_uid,
+        ) {
+            continue;
+        }
+
+        if !client_tracking_shard(client, shard_uid.shard_id(), &final_head.prev_block_hash) {
+            continue;
+        }
+
+        let memtrie = get_memtrie_for_shard(client, &shard_uid, &final_head.prev_block_hash);
+        let memtrie_state =
+            memtrie.lock_for_iter().iter().unwrap().collect::<Result<HashSet<_>, _>>().unwrap();
+
+        let state_root = *client
+            .chain
+            .get_chunk_extra(&final_head.prev_block_hash, &shard_uid)
+            .unwrap()
+            .state_root();
+
+        // To get a view on disk tries we can leverage the fact that get_view_trie_for_shard() never
+        // uses memtries. 
+        let trie = client
+            .runtime_adapter
+            .get_view_trie_for_shard(shard_uid.shard_id(), &final_head.prev_block_hash, state_root)
+            .unwrap();
+        assert!(!trie.has_memtries());
+        let trie_state =
+            trie.lock_for_iter().iter().unwrap().collect::<Result<HashSet<_>, _>>().unwrap();
+        assert_state_equal(&memtrie_state, &trie_state, shard_uid, "memtrie and trie");
+
+        let flat_storage_manager = client.chain.runtime_adapter.get_flat_storage_manager();
+        // FlatStorageChunkView::iter_range() used below to retrieve all key-value pairs in Flat
+        // Storage only looks at the data committed into the DB. For this reason comparing Flat
+        // Storage and Memtries makes sense only if we can retrieve a view at the same height from
+        // both.
+        if let FlatStorageStatus::Ready(status) =
+            flat_storage_manager.get_flat_storage_status(shard_uid)
+        {
+            if status.flat_head.hash != final_head.prev_block_hash {
+                tracing::warn!(target: "test", "skipping flat storage - memtrie state check");
+                continue;
+            } else {
+                tracing::debug!(target: "test", "checking flat storage - memtrie state");
+            }
+        } else {
+            continue;
+        };
+        let Some(flat_store_chunk_view) =
+            flat_storage_manager.chunk_view(shard_uid, final_head.last_block_hash)
+        else {
+            continue;
+        };
+        let flat_store_state = flat_store_chunk_view
+            .iter_range(None, None)
+            .map_ok(|(key, value)| {
+                let value = match value {
+                    FlatStateValue::Ref(value) => client
+                        .chain
+                        .chain_store()
+                        .store()
+                        .trie_store()
+                        .get(shard_uid, &value.hash)
+                        .unwrap()
+                        .to_vec(),
+                    FlatStateValue::Inlined(data) => data,
+                };
+                (key, value)
+            })
+            .collect::<Result<HashSet<_>, _>>()
+            .unwrap();
+
+        assert_state_equal(&memtrie_state, &flat_store_state, shard_uid, "memtrie and flat store");
+        checked_shards.push(shard_uid);
+    }
+    checked_shards
+}
+
+fn assert_state_equal(
+    values1: &HashSet<(Vec<u8>, Vec<u8>)>,
+    values2: &HashSet<(Vec<u8>, Vec<u8>)>,
+    shard_uid: ShardUId,
+    cmp_msg: &str,
+) {
+    let diff = values1.symmetric_difference(values2);
+    let mut has_diff = false;
+    for (key, value) in diff {
+        has_diff = true;
+        tracing::error!(target: "test", ?shard_uid, key=?key, ?value, "Difference in state between {}!", cmp_msg);
+    }
+    assert!(!has_diff, "{} state mismatch!", cmp_msg);
+}
+
+fn should_assert_state_sanity(
+    load_mem_tries_for_tracked_shards: bool,
+    is_resharded: bool,
+    shards_pending_resharding: &HashSet<ShardUId>,
+    shard_layout: &ShardLayout,
+    shard_uid: &ShardUId,
+) -> bool {
+    // Always assert if the tracked shards are loaded into memory.
+    if load_mem_tries_for_tracked_shards {
+        return true;
+    }
+
+    // In the old layout do not enforce except for shards pending resharding.
+    if !is_resharded && !shards_pending_resharding.contains(&shard_uid) {
+        return false;
+    }
+
+    // In the new layout do not enforce for shards that were not split.
+    if is_resharded && !shard_was_split(shard_layout, shard_uid.shard_id()) {
+        return false;
+    }
+
+    true
+}
+
+/// Asserts that all parent shard State is accessible via parent and children shards. 
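+///
+/// This relies on the `DBCol::State` key layout: an 8-byte `ShardUId` prefix
+/// identifying the shard that owns the trie node, followed by the node hash
+/// (decoded below via `key[0..8]` and `key[8..]`).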
+pub fn check_state_shard_uid_mapping_after_resharding(client: &Client, parent_shard_uid: ShardUId) {
+    let tip = client.chain.head().unwrap();
+    let epoch_id = tip.epoch_id;
+    let epoch_config = client.epoch_manager.get_epoch_config(&epoch_id).unwrap();
+    let children_shard_uids =
+        epoch_config.shard_layout.get_children_shards_uids(parent_shard_uid.shard_id()).unwrap();
+    assert_eq!(children_shard_uids.len(), 2);
+
+    let store = client.chain.chain_store.store().trie_store();
+    for kv in store.store().iter_raw_bytes(DBCol::State) {
+        let (key, value) = kv.unwrap();
+        let shard_uid = ShardUId::try_from_slice(&key[0..8]).unwrap();
+        // Just after resharding, no State data should yet be keyed by a child ShardUId.
+        assert!(!children_shard_uids.contains(&shard_uid));
+        if shard_uid != parent_shard_uid {
+            continue;
+        }
+        let node_hash = CryptoHash::try_from_slice(&key[8..]).unwrap();
+        let (value, _) = decode_value_with_rc(&value);
+        let parent_value = store.get(parent_shard_uid, &node_hash);
+        // Parent shard data must still be accessible using the parent ShardUId.
+        assert_eq!(&parent_value.unwrap()[..], value.unwrap());
+        // All parent shard data is available via both child shards.
+        for child_shard_uid in &children_shard_uids {
+            let child_value = store.get(*child_shard_uid, &node_hash);
+            assert_eq!(&child_value.unwrap()[..], value.unwrap());
+        }
+    }
+}