Skip to content

Commit

Permalink
refactor: remove shard functions from epoch manager (#12843)
Browse files Browse the repository at this point in the history
  • Loading branch information
stedfn authored Jan 30, 2025
1 parent 79378e7 commit 8f62a08
Show file tree
Hide file tree
Showing 25 changed files with 341 additions and 474 deletions.
19 changes: 12 additions & 7 deletions chain/chain/src/chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,10 @@ use near_async::time::{Clock, Duration, Instant};
use near_chain_configs::{MutableConfigValue, MutableValidatorSigner};
use near_chain_primitives::error::{BlockKnownError, Error, LogTransientStorageError};
use near_epoch_manager::shard_assignment::shard_id_to_uid;
use near_epoch_manager::shard_tracker::ShardTracker;
use near_epoch_manager::shard_tracker::{
get_prev_shard_id_from_prev_hash, get_prev_shard_ids, get_shard_layout_from_prev_block,
ShardTracker,
};
use near_epoch_manager::EpochManagerAdapter;
use near_primitives::bandwidth_scheduler::BandwidthRequests;
use near_primitives::block::{genesis_chunks, Block, BlockValidityError, Chunks, MaybeNew, Tip};
Expand Down Expand Up @@ -2151,8 +2154,10 @@ impl Chain {
let chunk_header = last_final_block_chunks
.get(shard_index)
.ok_or_else(|| Error::InvalidShardId(shard_uid.shard_id()))?;
let chunk_shard_layout =
self.epoch_manager.get_shard_layout_from_prev_block(chunk_header.prev_block_hash())?;
let chunk_shard_layout = get_shard_layout_from_prev_block(
self.epoch_manager.as_ref(),
chunk_header.prev_block_hash(),
)?;
let chunk_shard_uid =
ShardUId::from_shard_id_and_layout(chunk_header.shard_id(), &chunk_shard_layout);

Expand Down Expand Up @@ -3523,7 +3528,7 @@ impl Chain {
chunk_header: &ShardChunkHeader,
) -> Result<ChunkState, Error> {
let shard_layout =
self.epoch_manager.get_shard_layout_from_prev_block(prev_block.hash())?;
get_shard_layout_from_prev_block(self.epoch_manager.as_ref(), prev_block.hash())?;
let shard_id = chunk_header.shard_id();
let shard_index = shard_layout.get_shard_index(shard_id)?;
let prev_merkle_proofs =
Expand Down Expand Up @@ -3904,7 +3909,7 @@ impl Chain {
let epoch_height =
self.epoch_manager.get_epoch_height_from_prev_block(prev_prev_hash)?;
let shard_layout =
&self.epoch_manager.get_shard_layout_from_prev_block(prev_prev_hash)?;
&get_shard_layout_from_prev_block(self.epoch_manager.as_ref(), prev_prev_hash)?;
let shard_uids = shard_layout.shard_uids().enumerate().collect();

let make_snapshot_callback = &snapshot_callbacks.make_snapshot_callback;
Expand Down Expand Up @@ -4452,7 +4457,7 @@ impl Chain {
let epoch_id = epoch_manager.get_epoch_id_from_prev_block(prev_block.hash())?;
let shard_ids = epoch_manager.shard_ids(&epoch_id)?;

let prev_shard_ids = epoch_manager.get_prev_shard_ids(prev_block.hash(), shard_ids)?;
let prev_shard_ids = get_prev_shard_ids(epoch_manager, prev_block.hash(), shard_ids)?;
let prev_chunks = prev_block.chunks();
Ok(prev_shard_ids
.into_iter()
Expand All @@ -4466,7 +4471,7 @@ impl Chain {
shard_id: ShardId,
) -> Result<ShardChunkHeader, Error> {
let (_, prev_shard_id, prev_shard_index) =
epoch_manager.get_prev_shard_id_from_prev_hash(prev_block.hash(), shard_id)?;
get_prev_shard_id_from_prev_hash(epoch_manager, prev_block.hash(), shard_id)?;
Ok(prev_block
.chunks()
.get(prev_shard_index)
Expand Down
4 changes: 3 additions & 1 deletion chain/chain/src/chain_update.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use crate::{metrics, DoomslugThresholdMode};
use crate::{Chain, Doomslug};
use near_chain_primitives::error::Error;
use near_epoch_manager::shard_assignment::shard_id_to_uid;
use near_epoch_manager::shard_tracker::get_shard_layout_from_prev_block;
use near_epoch_manager::EpochManagerAdapter;
use near_primitives::apply::ApplyChunkReason;
use near_primitives::block::{Block, Tip};
Expand Down Expand Up @@ -297,7 +298,8 @@ impl<'a> ChainUpdate<'a> {
}
}

let shard_layout = self.epoch_manager.get_shard_layout_from_prev_block(prev.hash())?;
let shard_layout =
get_shard_layout_from_prev_block(self.epoch_manager.as_ref(), &prev.hash())?;
SHARD_LAYOUT_VERSION.set(shard_layout.version() as i64);
SHARD_LAYOUT_NUM_SHARDS.set(shard_layout.shard_ids().count() as i64);
}
Expand Down
3 changes: 2 additions & 1 deletion chain/chain/src/migrations.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::store::ChainStoreAccess;
use near_chain_primitives::error::Error;
use near_epoch_manager::shard_tracker::get_prev_shard_ids;
use near_epoch_manager::EpochManagerAdapter;
use near_primitives::hash::CryptoHash;
use near_primitives::types::ShardId;
Expand Down Expand Up @@ -29,7 +30,7 @@ pub fn check_if_block_is_first_with_chunk_of_version(
if is_first_epoch_with_protocol_version(epoch_manager, prev_block_hash)? {
// Compare only epochs because we already know that current epoch is the first one with current protocol version
// convert shard id to shard id of previous epoch because number of shards may change
let (shard_id, _) = epoch_manager.get_prev_shard_ids(prev_block_hash, vec![shard_id])?[0];
let (shard_id, _) = get_prev_shard_ids(epoch_manager, prev_block_hash, vec![shard_id])?[0];
let prev_epoch_id = chain_store.get_epoch_id_of_last_block_with_chunk(
epoch_manager,
prev_block_hash,
Expand Down
14 changes: 9 additions & 5 deletions chain/chain/src/runtime/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::{Chain, ChainGenesis, ChainStoreAccess, DoomslugThresholdMode};
use assert_matches::assert_matches;
use near_chain_configs::test_utils::{TESTING_INIT_BALANCE, TESTING_INIT_STAKE};
use near_epoch_manager::shard_assignment::shard_id_to_uid;
use near_epoch_manager::shard_tracker::ShardTracker;
use near_epoch_manager::shard_tracker::{get_shard_layout_from_prev_block, ShardTracker};
use near_epoch_manager::{EpochManager, RngSeed};
use near_pool::{
InsertTransactionResult, PoolIteratorWrapper, TransactionGroupIteratorWrapper, TransactionPool,
Expand Down Expand Up @@ -374,7 +374,8 @@ impl TestEnv {
.unwrap()
.commit()
.unwrap();
let shard_layout = self.epoch_manager.get_shard_layout_from_prev_block(&new_hash).unwrap();
let shard_layout =
get_shard_layout_from_prev_block(self.epoch_manager.as_ref(), &new_hash).unwrap();
let mut new_receipts = HashMap::<_, Vec<Receipt>>::new();
for receipt in all_receipts {
if receipt.send_to_all_shards() {
Expand Down Expand Up @@ -1459,7 +1460,8 @@ fn test_insufficient_stake() {
fn test_flat_state_usage() {
let env = TestEnv::new(vec![vec!["test1".parse().unwrap()]], 4, false);
let prev_hash = env.head.prev_block_hash;
let shard_layout = env.epoch_manager.get_shard_layout_from_prev_block(&prev_hash).unwrap();
let shard_layout =
get_shard_layout_from_prev_block(env.epoch_manager.as_ref(), &prev_hash).unwrap();
let shard_id = shard_layout.shard_ids().next().unwrap();
let state_root = Trie::EMPTY_ROOT;

Expand Down Expand Up @@ -1499,7 +1501,8 @@ fn test_trie_and_flat_state_equality() {
// - using state trie, which should use flat state after enabling it in the protocol
// - using view state, which should never use flat state
let prev_hash = env.head.prev_block_hash;
let shard_layout = env.epoch_manager.get_shard_layout_from_prev_block(&prev_hash).unwrap();
let shard_layout =
get_shard_layout_from_prev_block(env.epoch_manager.as_ref(), &prev_hash).unwrap();
let shard_id = shard_layout.shard_ids().next().unwrap();

let state_root = env.state_roots[0];
Expand Down Expand Up @@ -1644,7 +1647,8 @@ fn prepare_transactions(
storage_config: RuntimeStorageConfig,
) -> Result<PreparedTransactions, Error> {
let prev_hash = env.head.prev_block_hash;
let shard_layout = env.epoch_manager.get_shard_layout_from_prev_block(&prev_hash).unwrap();
let shard_layout =
get_shard_layout_from_prev_block(env.epoch_manager.as_ref(), &prev_hash).unwrap();
let shard_id = shard_layout.shard_ids().next().unwrap();
let block = chain.get_block(&prev_hash).unwrap();
let congestion_info = block.block_congestion_info();
Expand Down
19 changes: 12 additions & 7 deletions chain/chain/src/stateless_validation/chunk_validation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ use lru::LruCache;
use near_async::futures::AsyncComputationSpawnerExt;
use near_chain_primitives::Error;
use near_epoch_manager::shard_assignment::shard_id_to_uid;
use near_epoch_manager::shard_tracker::{
get_prev_shard_id_from_prev_hash, get_shard_layout_from_prev_block,
};
use near_epoch_manager::EpochManagerAdapter;
use near_pool::TransactionGroupIteratorWrapper;
use near_primitives::apply::ApplyChunkReason;
Expand Down Expand Up @@ -182,8 +185,7 @@ fn get_state_witness_block_range(

let initial_prev_hash = *state_witness.chunk_header.prev_block_hash();
let initial_prev_block = store.get_block(&initial_prev_hash)?;
let initial_shard_layout =
epoch_manager.get_shard_layout_from_prev_block(&initial_prev_hash)?;
let initial_shard_layout = get_shard_layout_from_prev_block(epoch_manager, &initial_prev_hash)?;
let initial_shard_id = state_witness.chunk_header.shard_id();
// Check that shard id is present in current epoch.
// TODO: consider more proper way to validate this.
Expand Down Expand Up @@ -212,7 +214,7 @@ fn get_state_witness_block_range(
implicit_transition_params.push(transition);
}
let (prev_shard_layout, prev_shard_id, prev_shard_index) =
epoch_manager.get_prev_shard_id_from_prev_hash(prev_hash, position.shard_id)?;
get_prev_shard_id_from_prev_hash(epoch_manager, prev_hash, position.shard_id)?;

let new_chunk_seen = block_has_new_chunk(&position.prev_block, prev_shard_index)?;
let new_chunks_seen_update =
Expand Down Expand Up @@ -285,7 +287,7 @@ fn get_resharding_transition(
return Ok(None);
}

let shard_layout = epoch_manager.get_shard_layout_from_prev_block(prev_header.hash())?;
let shard_layout = get_shard_layout_from_prev_block(epoch_manager, prev_header.hash())?;
let prev_epoch_id = epoch_manager.get_prev_epoch_id_from_prev_block(prev_header.hash())?;
let prev_shard_layout = epoch_manager.get_shard_layout(&prev_epoch_id)?;
let block_has_new_shard_layout = epoch_manager.is_next_block_epoch_start(prev_header.hash())?
Expand Down Expand Up @@ -553,9 +555,12 @@ fn validate_source_receipt_proofs(
receipts_to_apply.extend(proof.0.iter().cloned());
}

current_target_shard_id = epoch_manager
.get_prev_shard_id_from_prev_hash(block.header().prev_hash(), current_target_shard_id)?
.1;
current_target_shard_id = get_prev_shard_id_from_prev_hash(
epoch_manager,
block.header().prev_hash(),
current_target_shard_id,
)?
.1;
}

// Check that there are no extraneous proofs in source_receipt_proofs.
Expand Down
7 changes: 4 additions & 3 deletions chain/chain/src/store/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use borsh::{BorshDeserialize, BorshSerialize};
use chrono::Utc;

use near_chain_primitives::error::Error;
use near_epoch_manager::shard_tracker::{get_prev_shard_ids, get_shard_layout_from_prev_block};
use near_epoch_manager::EpochManagerAdapter;
use near_primitives::block::Tip;
use near_primitives::checked_feature;
Expand Down Expand Up @@ -243,7 +244,7 @@ pub trait ChainStoreAccess {
}

let prev_hash = header.prev_hash();
let prev_shard_layout = epoch_manager.get_shard_layout_from_prev_block(prev_hash)?;
let prev_shard_layout = get_shard_layout_from_prev_block(epoch_manager, prev_hash)?;

if prev_shard_layout != current_shard_layout {
let parent_shard_id = current_shard_layout.get_parent_shard_id(current_shard_id)?;
Expand Down Expand Up @@ -362,7 +363,7 @@ pub trait ChainStoreAccess {
}
candidate_hash = *block_header.prev_hash();
(shard_id, shard_index) =
epoch_manager.get_prev_shard_ids(&candidate_hash, vec![shard_id])?[0];
get_prev_shard_ids(epoch_manager, &candidate_hash, vec![shard_id])?[0];
}
}

Expand Down Expand Up @@ -492,7 +493,7 @@ impl ChainStore {
shard_id: ShardId,
last_included_height: BlockHeight,
) -> Result<Vec<Receipt>, Error> {
let shard_layout = epoch_manager.get_shard_layout_from_prev_block(&prev_block_hash)?;
let shard_layout = get_shard_layout_from_prev_block(epoch_manager, &prev_block_hash)?;
let mut receipts_block_hash = prev_block_hash;
loop {
let block_header = self.get_block_header(&receipts_block_hash)?;
Expand Down
129 changes: 0 additions & 129 deletions chain/chain/src/test_utils/kv_runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -564,45 +564,6 @@ impl EpochManagerAdapter for MockEpochManager {
Ok(self.get_epoch_and_valset(*parent_hash)?.2)
}

fn get_prev_shard_ids(
&self,
prev_hash: &CryptoHash,
shard_ids: Vec<ShardId>,
) -> Result<Vec<(ShardId, ShardIndex)>, Error> {
let mut prev_shard_ids = vec![];
let shard_layout = self.get_shard_layout_from_prev_block(prev_hash)?;
for shard_id in shard_ids {
// This is not correct if there was a resharding event in between
// the previous and current block.
let prev_shard_id = shard_id;
let prev_shard_index = shard_layout.get_shard_index(prev_shard_id)?;
prev_shard_ids.push((prev_shard_id, prev_shard_index));
}

Ok(prev_shard_ids)
}

fn get_prev_shard_id_from_prev_hash(
&self,
prev_hash: &CryptoHash,
shard_id: ShardId,
) -> Result<(ShardLayout, ShardId, ShardIndex), EpochError> {
let shard_layout = self.get_shard_layout_from_prev_block(prev_hash)?;
// This is not correct if there was a resharding event in between
// the previous and current block.
let prev_shard_id = shard_id;
let prev_shard_index = shard_layout.get_shard_index(prev_shard_id)?;
Ok((shard_layout, prev_shard_id, prev_shard_index))
}

fn get_shard_layout_from_prev_block(
&self,
_parent_hash: &CryptoHash,
) -> Result<ShardLayout, EpochError> {
#[allow(deprecated)]
Ok(ShardLayout::v0(self.num_shards, 0))
}

fn get_epoch_id(&self, block_hash: &CryptoHash) -> Result<EpochId, EpochError> {
let (epoch_id, _, _) = self.get_epoch_and_valset(*block_hash)?;
Ok(epoch_id)
Expand Down Expand Up @@ -886,96 +847,6 @@ impl EpochManagerAdapter for MockEpochManager {
Ok(true)
}

fn cares_about_shard_in_epoch(
&self,
epoch_id: &EpochId,
account_id: &AccountId,
shard_id: ShardId,
) -> Result<bool, EpochError> {
// This `unwrap` here tests that in all code paths we check that the epoch exists before
// we check if we care about a shard. Please do not remove the unwrap, fix the logic of
// the calling function.
let epoch_valset = self.get_valset_for_epoch(epoch_id).unwrap();
let shard_layout = self.get_shard_layout(epoch_id)?;
let shard_index = shard_layout.get_shard_index(shard_id)?;
let chunk_producers = self.get_chunk_producers(epoch_valset, shard_index);
for validator in chunk_producers {
if validator.account_id() == account_id {
return Ok(true);
}
}
Ok(false)
}

fn cares_about_shard_from_prev_block(
&self,
parent_hash: &CryptoHash,
account_id: &AccountId,
shard_id: ShardId,
) -> Result<bool, EpochError> {
// This `unwrap` here tests that in all code paths we check that the epoch exists before
// we check if we care about a shard. Please do not remove the unwrap, fix the logic of
// the calling function.
let epoch_valset = self.get_epoch_and_valset(*parent_hash).unwrap();
let shard_layout = self.get_shard_layout_from_prev_block(parent_hash)?;
let shard_index = shard_layout.get_shard_index(shard_id)?;
let chunk_producers = self.get_chunk_producers(epoch_valset.1, shard_index);
for validator in chunk_producers {
if validator.account_id() == account_id {
return Ok(true);
}
}
Ok(false)
}

fn cares_about_shard_next_epoch_from_prev_block(
&self,
parent_hash: &CryptoHash,
account_id: &AccountId,
shard_id: ShardId,
) -> Result<bool, EpochError> {
// This `unwrap` here tests that in all code paths we check that the epoch exists before
// we check if we care about a shard. Please do not remove the unwrap, fix the logic of
// the calling function.
let epoch_valset = self.get_epoch_and_valset(*parent_hash).unwrap();
let shard_layout = self.get_shard_layout_from_prev_block(parent_hash)?;
let shard_index = shard_layout.get_shard_index(shard_id)?;
let chunk_producers = self.get_chunk_producers(
(epoch_valset.1 + 1) % self.validators_by_valset.len(),
shard_index,
);
for validator in chunk_producers {
if validator.account_id() == account_id {
return Ok(true);
}
}
Ok(false)
}

fn cared_about_shard_prev_epoch_from_prev_block(
&self,
parent_hash: &CryptoHash,
account_id: &AccountId,
shard_id: ShardId,
) -> Result<bool, EpochError> {
// This `unwrap` here tests that in all code paths we check that the epoch exists before
// we check if we care about a shard. Please do not remove the unwrap, fix the logic of
// the calling function.
let epoch_valset = self.get_epoch_and_valset(*parent_hash).unwrap();
let shard_layout = self.get_shard_layout_from_prev_block(parent_hash)?;
let shard_index = shard_layout.get_shard_index(shard_id)?;
let chunk_producers = self.get_chunk_producers(
(epoch_valset.1.wrapping_sub(1)) % self.validators_by_valset.len(),
shard_index,
);
for validator in chunk_producers {
if validator.account_id() == account_id {
return Ok(true);
}
}
Ok(false)
}

fn will_shard_layout_change(&self, parent_hash: &CryptoHash) -> Result<bool, EpochError> {
// Copied from EpochManager (KeyValueRuntime is deprecated anyway).
let epoch_id = self.get_epoch_id_from_prev_block(parent_hash)?;
Expand Down
Loading

0 comments on commit 8f62a08

Please sign in to comment.