fix(resharding) - fix the buffered receipts forwarding (#12561)
The new logic introduced recently in #12538 relies on some assumptions
that do not hold in ShardLayoutV1. It's important to keep this shard
layout working because it is what we currently have in mainnet.

The issue was that in V1 the shard ids of parents that were split may
overlap with the current shard ids. In that case the same buffer would
be processed twice and crash the node.

The fix is to make the forwarding depend on the protocol version. For old
protocol versions we skip forwarding receipts from parents - there should
not be any before the new resharding. Starting from the new resharding it
is ok to assume unique shard ids, so the new logic works.
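
To make the overlap concrete, here is a minimal sketch in plain Rust of the gating described above. It is illustrative only, not the nearcore API: `new_resharding_enabled` stands in for the protocol feature check, and the two sets stand in for the current layout's shard ids and the split parents' shard ids.

```rust
use std::collections::BTreeSet;

// Illustrative only: decide which outgoing buffers to forward from.
fn buffers_to_forward(
    new_resharding_enabled: bool,
    current_shard_ids: BTreeSet<u64>,
    parent_shard_ids: BTreeSet<u64>,
) -> Vec<u64> {
    if !new_resharding_enabled {
        // Old protocol versions: skip parent buffers entirely; there should
        // not be any buffered receipts from parents before the new resharding.
        return current_shard_ids.into_iter().collect();
    }
    // From the new resharding onward shard ids are unique, so forwarding from
    // parents and from current shards can never hit the same buffer twice.
    debug_assert!(current_shard_ids.is_disjoint(&parent_shard_ids));
    parent_shard_ids.into_iter().chain(current_shard_ids).collect()
}
```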

For reasons that I can't remember the `test_resharding_v3_load_mem_trie`
test broke and I needed to make some adjustments to it. Maybe it's
because the new base shard layout is non-trivial now.
wacban authored Dec 6, 2024
1 parent 09ed12a commit c151baa
Showing 5 changed files with 131 additions and 46 deletions.
5 changes: 5 additions & 0 deletions chain/chain/src/resharding/event_type.rs
@@ -79,6 +79,11 @@ impl ReshardingEventType {
return log_and_error("can't perform two reshardings at the same time!");
}
// Parent shard is no longer part of this shard layout.
//
// Please note the use of the next shard layout version.
// Technically speaking the current shard layout version
// should be used for the parent. However since
// ShardLayoutV2 the version is frozen so it is ok.
let parent_shard = ShardUId::new(next_shard_layout.version(), *parent_id);
let left_child_shard =
ShardUId::from_shard_id_and_layout(children_ids[0], next_shard_layout);
19 changes: 18 additions & 1 deletion core/primitives/src/shard_layout.rs
@@ -7,7 +7,7 @@ use near_schema_checker_lib::ProtocolSchema;
use rand::rngs::StdRng;
use rand::seq::SliceRandom;
use rand::SeedableRng;
use std::collections::BTreeMap;
use std::collections::{BTreeMap, BTreeSet};
use std::{fmt, str};

/// This file implements two data structure `ShardLayout` and `ShardUId`
@@ -784,6 +784,23 @@ impl ShardLayout {
.ok_or(ShardLayoutError::InvalidShardIndexError { shard_index }),
}
}

/// Returns all of the shards from the previous shard layout that were
/// split into multiple shards in this shard layout.
pub fn get_parent_shard_ids(&self) -> Result<BTreeSet<ShardId>, ShardLayoutError> {
let mut parent_shard_ids = BTreeSet::new();
for shard_id in self.shard_ids() {
let parent_shard_id = self.try_get_parent_shard_id(shard_id)?;
let Some(parent_shard_id) = parent_shard_id else {
continue;
};
if parent_shard_id == shard_id {
continue;
}
parent_shard_ids.insert(parent_shard_id);
}
Ok(parent_shard_ids)
}
}

/// Maps an account to the shard that it belongs to given a shard_layout
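
For readers of this diff, here is a self-contained toy version of the `get_parent_shard_ids` walk added above, using a plain map from shard id to optional parent id instead of the real `ShardLayout` (the shard numbers are hypothetical):

```rust
use std::collections::{BTreeMap, BTreeSet};

// Toy restatement of the `get_parent_shard_ids` logic; not nearcore code.
fn parent_shard_ids(shard_to_parent: &BTreeMap<u64, Option<u64>>) -> BTreeSet<u64> {
    let mut parents = BTreeSet::new();
    for (&shard_id, &parent) in shard_to_parent {
        let Some(parent_shard_id) = parent else { continue };
        // A shard that maps to itself was carried over, not split, so it is
        // skipped - this mirrors the `parent_shard_id == shard_id` check.
        if parent_shard_id == shard_id {
            continue;
        }
        parents.insert(parent_shard_id);
    }
    parents
}

fn main() {
    // Hypothetical layout: shards 4 and 5 were split out of shard 1, while
    // shard 0 was carried over unchanged.
    let layout = BTreeMap::from([(0, Some(0)), (4, Some(1)), (5, Some(1))]);
    assert_eq!(parent_shard_ids(&layout), BTreeSet::from([1]));
}
```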
24 changes: 17 additions & 7 deletions integration-tests/src/test_loop/tests/fix_min_stake_ratio.rs
@@ -15,6 +15,7 @@ use near_primitives::num_rational::Rational32;
use near_primitives::test_utils::create_user_test_signer;
use near_primitives::transaction::SignedTransaction;
use near_primitives::types::AccountId;
use near_primitives::upgrade_schedule::ProtocolUpgradeVotingSchedule;
use near_primitives_core::version::ProtocolFeature;
use std::string::ToString;
use std::sync::atomic::{AtomicU64, Ordering};
@@ -24,7 +25,18 @@ use std::sync::atomic::{AtomicU64, Ordering};
#[test]
fn slow_test_fix_min_stake_ratio() {
init_test_logger();
let builder = TestLoopBuilder::new();

// Take epoch configuration before the protocol upgrade, where minimum
// stake ratio was 1/6250.
let epoch_config_store = EpochConfigStore::for_chain_id("mainnet", None).unwrap();
let target_protocol_version = ProtocolFeature::FixMinStakeRatio.protocol_version();
let genesis_protocol_version = target_protocol_version - 1;

// Immediately start voting for the new protocol version
let protocol_upgrade_schedule =
ProtocolUpgradeVotingSchedule::new_immediate(target_protocol_version);

let builder = TestLoopBuilder::new().protocol_upgrade_schedule(protocol_upgrade_schedule);

let initial_balance = 1_000_000 * ONE_NEAR;
let epoch_length = 10;
@@ -56,18 +68,16 @@ fn slow_test_fix_min_stake_ratio() {
},
];

// Take epoch configuration before the protocol upgrade, where minimum
// stake ratio was 1/6250.
let epoch_config_store = EpochConfigStore::for_chain_id("mainnet", None).unwrap();
let protocol_version = ProtocolFeature::FixMinStakeRatio.protocol_version() - 1;
let shard_layout =
epoch_config_store.get_config(genesis_protocol_version).as_ref().shard_layout.clone();

// Create chain with version before FixMinStakeRatio was enabled.
// Check that the small validator is not included in the validator set.
let mut genesis_builder = TestGenesisBuilder::new();
genesis_builder
.genesis_time_from_clock(&builder.clock())
.shard_layout(epoch_config_store.get_config(protocol_version).as_ref().shard_layout.clone())
.protocol_version(protocol_version)
.shard_layout(shard_layout)
.protocol_version(genesis_protocol_version)
.epoch_length(epoch_length)
.validators_raw(validators, 1, 1, 2)
// Disable validator rewards.
98 changes: 77 additions & 21 deletions integration-tests/src/test_loop/tests/resharding_v3.rs
@@ -12,7 +12,7 @@ use near_primitives::epoch_manager::EpochConfigStore;
use near_primitives::hash::CryptoHash;
use near_primitives::shard_layout::{account_id_to_shard_uid, ShardLayout};
use near_primitives::state_record::StateRecord;
use near_primitives::types::{AccountId, BlockHeightDelta, EpochId, Gas, ShardId};
use near_primitives::types::{AccountId, BlockHeightDelta, EpochId, Gas, NumShards, ShardId};
use near_primitives::version::{ProtocolFeature, PROTOCOL_VERSION};
use near_store::adapter::StoreAdapter;
use near_store::db::refcount::decode_value_with_rc;
@@ -590,24 +590,50 @@ fn shard_was_split(shard_layout: &ShardLayout, shard_id: ShardId) -> bool {
parent != shard_id
}

/// Asserts that for each child shard:
/// MemTrie, FlatState and DiskTrie all contain the same key-value pairs.
/// If `load_mem_tries_for_tracked_shards` is false, we only enforce memtries for split shards
/// Returns the ShardUIds that this client tracks and has sane memtries and flat storage for
/// Asserts that for each child shard, MemTrie, FlatState and DiskTrie all
/// contain the same key-value pairs. If `load_mem_tries_for_tracked_shards` is
/// false, we only enforce memtries for shards pending resharding in the old
layout and the shards that were split in the new shard layout.
///
/// Returns the ShardUIds that this client tracks and has sane memtries and flat
/// storage for
///
/// The new num shards argument is a clumsy way to check if the head is before
/// or after resharding.
fn assert_state_sanity(
client: &Client,
final_head: &Tip,
load_mem_tries_for_tracked_shards: bool,
new_num_shards: NumShards,
) -> Vec<ShardUId> {
let shard_layout = client.epoch_manager.get_shard_layout(&final_head.epoch_id).unwrap();
let is_resharded = shard_layout.num_shards() == new_num_shards;
let mut checked_shards = Vec::new();

let protocol_version =
client.epoch_manager.get_epoch_protocol_version(&final_head.epoch_id).unwrap();
let shards_pending_resharding = client
.epoch_manager
.get_shard_uids_pending_resharding(protocol_version, PROTOCOL_VERSION)
.unwrap();

for shard_uid in shard_layout.shard_uids() {
if !load_mem_tries_for_tracked_shards
&& !shard_was_split(&shard_layout, shard_uid.shard_id())
{
continue;
// TODO - the condition for checks is duplicated in the
// `get_epoch_check` method, refactor this.
if !load_mem_tries_for_tracked_shards {
// In the old layout do not enforce except for shards pending resharding.
if !is_resharded && !shards_pending_resharding.contains(&shard_uid) {
tracing::debug!(target: "test", ?shard_uid, "skipping shard not pending resharding");
continue;
}

// In the new layout do not enforce for shards that were not split.
if is_resharded && !shard_was_split(&shard_layout, shard_uid.shard_id()) {
tracing::debug!(target: "test", ?shard_uid, "skipping shard not split");
continue;
}
}

if !client_tracking_shard(client, shard_uid.shard_id(), &final_head.prev_block_hash) {
continue;
}
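
As a reading aid for the nested skip conditions above, this is a hedged, self-contained restatement of when a shard is enforced in the check (plain booleans instead of the real client and epoch-manager types):

```rust
// Illustrative restatement of the skip logic in `assert_state_sanity`;
// not nearcore code. A shard is checked when memtries are loaded for all
// tracked shards, or - before the resharding - when it is pending
// resharding, or - after the resharding - when it is a freshly split child.
fn should_check_shard(
    load_mem_tries_for_tracked_shards: bool,
    is_resharded: bool,
    pending_resharding: bool,
    was_split: bool,
) -> bool {
    if load_mem_tries_for_tracked_shards {
        return true;
    }
    if !is_resharded {
        return pending_resharding;
    }
    was_split
}
```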
@@ -696,22 +722,44 @@ impl TrieSanityCheck {
}

// If it's not already stored, initialize it with the expected ShardUIds for each account
fn get_epoch_check(&mut self, client: &Client, tip: &Tip) -> &mut EpochTrieCheck {
fn get_epoch_check(
&mut self,
client: &Client,
tip: &Tip,
new_num_shards: NumShards,
) -> &mut EpochTrieCheck {
let protocol_version =
client.epoch_manager.get_epoch_protocol_version(&tip.epoch_id).unwrap();
let shards_pending_resharding = client
.epoch_manager
.get_shard_uids_pending_resharding(protocol_version, PROTOCOL_VERSION)
.unwrap();
let shard_layout = client.epoch_manager.get_shard_layout(&tip.epoch_id).unwrap();
let is_resharded = shard_layout.num_shards() == new_num_shards;

match self.checks.entry(tip.epoch_id) {
std::collections::hash_map::Entry::Occupied(e) => e.into_mut(),
std::collections::hash_map::Entry::Vacant(e) => {
let shard_layout = client.epoch_manager.get_shard_layout(&tip.epoch_id).unwrap();
let shard_uids = shard_layout.shard_uids().collect_vec();
let mut check = HashMap::new();
for account_id in self.accounts.iter() {
let tracked = shard_uids
.iter()
.filter_map(|uid| {
if !self.load_mem_tries_for_tracked_shards
if !is_resharded
&& !self.load_mem_tries_for_tracked_shards
&& !shards_pending_resharding.contains(uid)
{
return None;
}

if is_resharded
&& !self.load_mem_tries_for_tracked_shards
&& !shard_was_split(&shard_layout, uid.shard_id())
{
return None;
}

let cares = client.shard_tracker.care_about_shard(
Some(account_id),
&tip.prev_block_hash,
@@ -733,7 +781,7 @@ impl TrieSanityCheck {
}

// Check trie sanity and keep track of which shards were successfully fully checked
fn assert_state_sanity(&mut self, clients: &[&Client]) {
fn assert_state_sanity(&mut self, clients: &[&Client], new_num_shards: NumShards) {
for client in clients {
let signer = client.validator_signer.get();
let Some(account_id) = signer.as_ref().map(|s| s.validator_id()) else {
@@ -753,9 +801,13 @@ impl TrieSanityCheck {
if head.epoch_id != final_head.epoch_id {
continue;
}
let checked_shards =
assert_state_sanity(client, &final_head, self.load_mem_tries_for_tracked_shards);
let check = self.get_epoch_check(client, &head);
let checked_shards = assert_state_sanity(
client,
&final_head,
self.load_mem_tries_for_tracked_shards,
new_num_shards,
);
let check = self.get_epoch_check(client, &head, new_num_shards);
let check = check.get_mut(account_id).unwrap();
for shard_uid in checked_shards {
check.insert(shard_uid, true);
@@ -834,8 +886,13 @@ fn test_resharding_v3_base(params: TestReshardingParameters) {
base_epoch_config.chunk_validator_only_kickout_threshold = 0;
}

// Set the base shard layout to V1 with non-unique shard ids. This is what
// we have on mainnet before the ReshardingV3 release.
// TODO(resharding) test both V1->V2 and V2->V2
let boundary_accounts = vec!["account1".parse().unwrap(), "account3".parse().unwrap()];
let base_shard_layout = ShardLayout::multi_shard_custom(boundary_accounts, 3);
let split_map = vec![vec![ShardId::new(0), ShardId::new(1), ShardId::new(2)]];
#[allow(deprecated)]
let base_shard_layout = ShardLayout::v1(boundary_accounts, Some(split_map), 3);

base_epoch_config.shard_layout = base_shard_layout.clone();
let new_boundary_account = "account6".parse().unwrap();
@@ -846,7 +903,7 @@ fn test_resharding_v3_base(params: TestReshardingParameters) {
ShardLayout::derive_shard_layout(&base_shard_layout, new_boundary_account);
tracing::info!(target: "test", ?base_shard_layout, new_shard_layout=?epoch_config.shard_layout, "shard layout");

let expected_num_shards = epoch_config.shard_layout.shard_ids().count();
let expected_num_shards = epoch_config.shard_layout.num_shards();
let epoch_config_store = EpochConfigStore::test(BTreeMap::from_iter(vec![
(base_protocol_version, Arc::new(base_epoch_config)),
(base_protocol_version + 1, Arc::new(epoch_config)),
@@ -945,9 +1002,8 @@ fn test_resharding_v3_base(params: TestReshardingParameters) {
println!("State before resharding:");
print_and_assert_shard_accounts(&clients, &tip);
}
trie_sanity_check.assert_state_sanity(&clients);
trie_sanity_check.assert_state_sanity(&clients, expected_num_shards);
latest_block_height.set(tip.height);
println!("block: {} chunks: {:?}", tip.height, block_header.chunk_mask());
if params.all_chunks_expected && params.chunk_ranges_to_drop.is_empty() {
assert!(block_header.chunk_mask().iter().all(|chunk_bit| *chunk_bit));
}
@@ -960,7 +1016,7 @@ fn test_resharding_v3_base(params: TestReshardingParameters) {
let prev_epoch_id =
client.epoch_manager.get_prev_epoch_id_from_prev_block(&tip.prev_block_hash).unwrap();
let epoch_config = client.epoch_manager.get_epoch_config(&prev_epoch_id).unwrap();
if epoch_config.shard_layout.shard_ids().count() != expected_num_shards {
if epoch_config.shard_layout.num_shards() != expected_num_shards {
return false;
}

31 changes: 14 additions & 17 deletions runtime/runtime/src/congestion_control.rs
@@ -234,24 +234,21 @@ impl ReceiptSinkV2 {

let protocol_version = apply_state.current_protocol_version;
let shard_layout = epoch_info_provider.shard_layout(&apply_state.epoch_id)?;
let shard_ids = if ProtocolFeature::SimpleNightshadeV4.enabled(protocol_version) {
shard_layout.shard_ids().collect_vec()
} else {
self.outgoing_limit.keys().copied().collect_vec()
};

let mut parent_shard_ids = BTreeSet::new();
for &shard_id in &shard_ids {
let parent_shard_id =
shard_layout.try_get_parent_shard_id(shard_id).map_err(Into::<EpochError>::into)?;
let Some(parent_shard_id) = parent_shard_id else {
continue;
let (shard_ids, parent_shard_ids) =
if ProtocolFeature::SimpleNightshadeV4.enabled(protocol_version) {
(
shard_layout.shard_ids().collect_vec(),
shard_layout.get_parent_shard_ids().map_err(Into::<EpochError>::into)?,
)
} else {
(self.outgoing_limit.keys().copied().collect_vec(), BTreeSet::new())
};
if parent_shard_id == shard_id {
continue;
}
parent_shard_ids.insert(parent_shard_id);
}

// There mustn't be any shard ids in both the parents and the current
// shard ids. If this happens the same buffer will be processed twice.
debug_assert!(
parent_shard_ids.intersection(&shard_ids.clone().into_iter().collect()).count() == 0
);

// First forward any receipts that may still be in the outgoing buffers
// of the parent shards.