Skip to content

Commit

Permalink
refactor: same treatment for get_shard_uids() as for shard_ids() (#10305)
Browse files Browse the repository at this point in the history

- renames `get_shard_uids()` to `shard_uids()`
- `shard_uids()` returns an iterator instead of a vector.
  • Loading branch information
akhi3030 authored Dec 7, 2023
1 parent 8985a96 commit abf6e99
Show file tree
Hide file tree
Showing 18 changed files with 56 additions and 57 deletions.
9 changes: 5 additions & 4 deletions chain/chain/src/chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -711,8 +711,7 @@ impl Chain {
let flat_storage_manager = runtime_adapter.get_flat_storage_manager();
let genesis_epoch_id = genesis.header().epoch_id();
let mut tmp_store_update = store_update.store().store_update();
for shard_uid in epoch_manager.get_shard_layout(genesis_epoch_id)?.get_shard_uids()
{
for shard_uid in epoch_manager.get_shard_layout(genesis_epoch_id)?.shard_uids() {
flat_storage_manager.set_flat_storage_for_genesis(
&mut tmp_store_update,
shard_uid,
Expand All @@ -737,7 +736,8 @@ impl Chain {
// TODO(#9511): The calculation of shard UIDs is not precise in the case
// of resharding. We need to revisit this.
let tip = store.head()?;
let shard_uids = epoch_manager.get_shard_layout(&tip.epoch_id)?.get_shard_uids();
let shard_uids: Vec<_> =
epoch_manager.get_shard_layout(&tip.epoch_id)?.shard_uids().collect();
runtime_adapter.load_mem_tries_on_startup(&shard_uids)?;

info!(target: "chain", "Init: header head @ #{} {}; block head @ #{} {}",
Expand Down Expand Up @@ -4097,7 +4097,8 @@ impl Chain {
let shard_uids = self
.epoch_manager
.get_shard_layout_from_prev_block(&head.prev_block_hash)?
.get_shard_uids();
.shard_uids()
.collect();
let last_block = self.get_block(&head.last_block_hash)?;
let make_snapshot_callback = &snapshot_callbacks.make_snapshot_callback;
make_snapshot_callback(head.prev_block_hash, epoch_height, shard_uids, last_block);
Expand Down
8 changes: 4 additions & 4 deletions chain/chain/src/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2312,7 +2312,7 @@ impl<'a> ChainStoreUpdate<'a> {
let shard_layout =
epoch_manager.get_shard_layout(block_header.epoch_id()).expect("epoch info must exist");
// gc shards in this epoch
let mut shard_uids_to_gc: Vec<_> = shard_layout.get_shard_uids();
let mut shard_uids_to_gc: Vec<_> = shard_layout.shard_uids().collect();
// gc shards in the shard layout in the next epoch if shards will change in the next epoch
// Suppose shard changes at epoch T, we need to garbage collect the new shard layout
// from the last block in epoch T-2 to the last block in epoch T-1
Expand All @@ -2322,7 +2322,7 @@ impl<'a> ChainStoreUpdate<'a> {
let next_shard_layout =
epoch_manager.get_shard_layout(next_epoch_id).expect("epoch info must exist");
if shard_layout != next_shard_layout {
shard_uids_to_gc.extend(next_shard_layout.get_shard_uids());
shard_uids_to_gc.extend(next_shard_layout.shard_uids());
}
shard_uids_to_gc
}
Expand Down Expand Up @@ -2368,7 +2368,7 @@ impl<'a> ChainStoreUpdate<'a> {

// Now we can proceed to removing the trie state and flat state
let mut store_update = self.store().store_update();
for shard_uid in prev_shard_layout.get_shard_uids() {
for shard_uid in prev_shard_layout.shard_uids() {
tracing::info!(target: "garbage_collection", ?block_hash, ?shard_uid, "GC resharding");
runtime.get_tries().delete_trie_for_shard(shard_uid, &mut store_update);
runtime
Expand Down Expand Up @@ -2939,7 +2939,7 @@ impl<'a> ChainStoreUpdate<'a> {
.block_extras
.insert(*block_hash, source_store.get_block_extra(block_hash)?);
let shard_layout = source_epoch_manager.get_shard_layout(&header.epoch_id())?;
for shard_uid in shard_layout.get_shard_uids() {
for shard_uid in shard_layout.shard_uids() {
chain_store_update.chain_store_cache_update.chunk_extras.insert(
(*block_hash, shard_uid),
source_store.get_chunk_extra(block_hash, &shard_uid)?.clone(),
Expand Down
3 changes: 1 addition & 2 deletions chain/chunks/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,7 @@ impl ShardedTransactionPool {

let mut transactions = vec![];

let old_shard_uids = old_shard_layout.get_shard_uids();
for old_shard_uid in old_shard_uids {
for old_shard_uid in old_shard_layout.shard_uids() {
if let Some(mut iter) = self.get_pool_iterator(old_shard_uid) {
while let Some(group) = iter.next() {
while let Some(tx) = group.next() {
Expand Down
2 changes: 1 addition & 1 deletion chain/client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ impl Client {
epoch_manager.get_shard_layout(&epoch_id).expect("Cannot get shard layout.");
match state_sync_adapter.write() {
Ok(mut state_sync_adapter) => {
for shard_uid in shard_layout.get_shard_uids() {
for shard_uid in shard_layout.shard_uids() {
state_sync_adapter.start(shard_uid);
}
}
Expand Down
9 changes: 4 additions & 5 deletions core/primitives/src/shard_layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,11 +227,10 @@ impl ShardLayout {
0..self.num_shards()
}

/// Returns shard uids for all shards in the shard layout
pub fn get_shard_uids(&self) -> Vec<ShardUId> {
self.shard_ids()
.map(|shard_id| ShardUId::from_shard_id_and_layout(shard_id, self))
.collect()
/// Returns an iterator that iterates over all the shard uids for all the
/// shards in the shard layout
pub fn shard_uids(&self) -> impl Iterator<Item = ShardUId> + '_ {
self.shard_ids().map(|shard_id| ShardUId::from_shard_id_and_layout(shard_id, self))
}
}

Expand Down
7 changes: 3 additions & 4 deletions core/store/src/cold_storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,14 +286,13 @@ fn get_keys_from_store(
.map(|shard_id| shard_id.to_le_bytes().to_vec())
.collect(),
DBKeyType::ShardUId => shard_layout
.get_shard_uids()
.iter()
.map(|uid| uid.to_bytes().to_vec())
.shard_uids()
.map(|shard_uid| shard_uid.to_bytes().to_vec())
.collect(),
// TODO: don't write values of State column to cache. Write them directly to colddb.
DBKeyType::TrieNodeOrValueHash => {
let mut keys = vec![];
for shard_uid in shard_layout.get_shard_uids() {
for shard_uid in shard_layout.shard_uids() {
let shard_uid_key = shard_uid.to_bytes();

debug_assert_eq!(
Expand Down
6 changes: 3 additions & 3 deletions core/store/src/flat/inlining_migration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ mod tests {
#[test]
fn full_migration() {
let store = NodeStorage::test_opener().1.open().unwrap().get_hot_store();
let shard_uid = ShardLayout::v0_single_shard().get_shard_uids()[0];
let shard_uid = ShardLayout::v0_single_shard().shard_uids().next().unwrap();
let values = [
vec![0],
vec![1],
Expand Down Expand Up @@ -364,7 +364,7 @@ mod tests {
fn block_migration() {
init_test_logger();
let store = NodeStorage::test_opener().1.open().unwrap().get_hot_store();
let shard_uid = ShardLayout::v0_single_shard().get_shard_uids()[0];
let shard_uid = ShardLayout::v0_single_shard().shard_uids().next().unwrap();
let values = [
vec![0],
vec![1],
Expand Down Expand Up @@ -403,7 +403,7 @@ mod tests {
fn interrupt_blocked_migration() {
init_test_logger();
let store = NodeStorage::test_opener().1.open().unwrap().get_hot_store();
let shard_uid = ShardLayout::v0_single_shard().get_shard_uids()[0];
let shard_uid = ShardLayout::v0_single_shard().shard_uids().next().unwrap();
let values = [
vec![0],
vec![1],
Expand Down
2 changes: 1 addition & 1 deletion core/store/src/genesis/initialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ fn genesis_state_from_genesis(store: Store, genesis: &Genesis) -> Vec<StateRoot>
let tries = ShardTries::new(
store.clone(),
TrieConfig::default(),
&genesis.config.shard_layout.get_shard_uids(),
&genesis.config.shard_layout.shard_uids().collect::<Vec<_>>(),
FlatStorageManager::new(store),
StateSnapshotConfig::default(),
);
Expand Down
12 changes: 6 additions & 6 deletions integration-tests/src/tests/client/flat_storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ fn wait_for_flat_storage_creation(
fn test_flat_storage_creation_sanity() {
init_test_logger();
let genesis = Genesis::test(vec!["test0".parse().unwrap()], 1);
let shard_uid = genesis.config.shard_layout.get_shard_uids()[0];
let shard_uid = genesis.config.shard_layout.shard_uids().next().unwrap();
let store = create_test_store();

// Process some blocks with flat storage. Then remove flat storage data from disk.
Expand Down Expand Up @@ -249,7 +249,7 @@ fn test_flat_storage_creation_two_shards() {
let num_shards = 2;
let genesis =
Genesis::test_sharded_new_version(vec!["test0".parse().unwrap()], 1, vec![1; num_shards]);
let shard_uids = genesis.config.shard_layout.get_shard_uids();
let shard_uids: Vec<_> = genesis.config.shard_layout.shard_uids().collect();
let store = create_test_store();

// Process some blocks with flat storages for two shards. Then remove flat storage data from disk for shard 0.
Expand All @@ -271,7 +271,7 @@ fn test_flat_storage_creation_two_shards() {
assert_eq!(env.clients[0].process_tx(tx, false, false), ProcessTxResponse::ValidTx);
}

for &shard_uid in &shard_uids {
for &shard_uid in shard_uids.iter() {
assert_matches!(
store_helper::get_flat_storage_status(&store, shard_uid),
Ok(FlatStorageStatus::Ready(_))
Expand Down Expand Up @@ -310,7 +310,7 @@ fn test_flat_storage_creation_start_from_state_part() {
let accounts =
(0..4).map(|i| AccountId::from_str(&format!("test{}", i)).unwrap()).collect::<Vec<_>>();
let genesis = Genesis::test(accounts, 1);
let shard_uid = genesis.config.shard_layout.get_shard_uids()[0];
let shard_uid = genesis.config.shard_layout.shard_uids().next().unwrap();
let store = create_test_store();

// Process some blocks with flat storage.
Expand Down Expand Up @@ -415,7 +415,7 @@ fn test_catchup_succeeds_even_if_no_new_blocks() {
init_test_logger();
let genesis = Genesis::test(vec!["test0".parse().unwrap()], 1);
let store = create_test_store();
let shard_uid = ShardLayout::v0_single_shard().get_shard_uids()[0];
let shard_uid = ShardLayout::v0_single_shard().shard_uids().next().unwrap();

// Process some blocks with flat storage. Then remove flat storage data from disk.
{
Expand Down Expand Up @@ -511,7 +511,7 @@ fn test_not_supported_block() {
init_test_logger();
let genesis = Genesis::test(vec!["test0".parse().unwrap()], 1);
let shard_layout = ShardLayout::v0_single_shard();
let shard_uid = shard_layout.get_shard_uids()[0];
let shard_uid = shard_layout.shard_uids().next().unwrap();
let store = create_test_store();

let mut env = setup_env(&genesis, store);
Expand Down
9 changes: 5 additions & 4 deletions integration-tests/src/tests/client/resharding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,11 @@ fn get_genesis_protocol_version(resharding_type: &ReshardingType) -> ProtocolVer
}

fn get_parent_shard_uids(resharding_type: &ReshardingType) -> Vec<ShardUId> {
match resharding_type {
ReshardingType::V1 => ShardLayout::v0_single_shard().get_shard_uids(),
ReshardingType::V2 => ShardLayout::get_simple_nightshade_layout().get_shard_uids(),
}
let shard_layout = match resharding_type {
ReshardingType::V1 => ShardLayout::v0_single_shard(),
ReshardingType::V2 => ShardLayout::get_simple_nightshade_layout(),
};
shard_layout.shard_uids().collect()
}

// Return the expected number of shards.
Expand Down
14 changes: 7 additions & 7 deletions nearcore/src/entity_debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ impl EntityDebugHandlerImpl {
match query {
EntityQuery::AllShardsByEpochId { epoch_id } => {
let shard_layout = self.epoch_manager.get_shard_layout(&epoch_id)?;
Ok(serialize_entity(&shard_layout.get_shard_uids()))
Ok(serialize_entity(&shard_layout.shard_uids().collect::<Vec<_>>()))
}
EntityQuery::BlockByHash { block_hash } => {
let block = self
Expand Down Expand Up @@ -183,9 +183,9 @@ impl EntityDebugHandlerImpl {
}
EntityQuery::ShardUIdByShardId { shard_id, epoch_id } => {
let shard_layout = self.epoch_manager.get_shard_layout(&epoch_id)?;
let shard_uid = *shard_layout
.get_shard_uids()
.get(shard_id as usize)
let shard_uid = shard_layout
.shard_uids()
.nth(shard_id as usize)
.ok_or_else(|| anyhow!("Shard {} not found", shard_id))?;
Ok(serialize_entity(&shard_uid))
}
Expand Down Expand Up @@ -324,9 +324,9 @@ impl EntityDebugHandlerImpl {
let shard_layout = self
.epoch_manager
.get_shard_layout_from_prev_block(&chunk.cloned_header().prev_block_hash())?;
let shard_uid = *shard_layout
.get_shard_uids()
.get(chunk.shard_id() as usize)
let shard_uid = shard_layout
.shard_uids()
.nth(chunk.shard_id() as usize)
.ok_or_else(|| anyhow!("Shard {} not found", chunk.shard_id()))?;
let path =
TriePath { path: vec![], shard_uid, state_root: chunk.prev_state_root() };
Expand Down
5 changes: 3 additions & 2 deletions nearcore/src/runtime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,18 +137,19 @@ impl NightshadeRuntime {
let runtime = Runtime::new();
let trie_viewer = TrieViewer::new(trie_viewer_state_size_limit, max_gas_burnt_view);
let flat_storage_manager = FlatStorageManager::new(store.clone());
let shard_uids: Vec<_> = genesis_config.shard_layout.shard_uids().collect();
let tries = ShardTries::new(
store.clone(),
trie_config,
&genesis_config.shard_layout.get_shard_uids(),
&shard_uids,
flat_storage_manager,
state_snapshot_config,
);
if let Err(err) = tries.maybe_open_state_snapshot(|prev_block_hash: CryptoHash| {
let epoch_manager = epoch_manager.read();
let epoch_id = epoch_manager.get_epoch_id(&prev_block_hash)?;
let shard_layout = epoch_manager.get_shard_layout(&epoch_id)?;
Ok(shard_layout.get_shard_uids())
Ok(shard_layout.shard_uids().collect())
}) {
tracing::error!(target: "runtime", ?err, "Failed to check if a state snapshot exists");
}
Expand Down
4 changes: 1 addition & 3 deletions nearcore/src/runtime/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,7 @@ impl TestEnv {
// Create flat storage. Naturally it happens on Chain creation, but here we test only Runtime behaviour
// and use a mock chain, so we need to initialize flat storage manually.
let flat_storage_manager = runtime.get_flat_storage_manager();
for shard_uid in
epoch_manager.get_shard_layout(&EpochId::default()).unwrap().get_shard_uids()
{
for shard_uid in epoch_manager.get_shard_layout(&EpochId::default()).unwrap().shard_uids() {
let mut store_update = store.store_update();
flat_storage_manager.set_flat_storage_for_genesis(
&mut store_update,
Expand Down
10 changes: 5 additions & 5 deletions tools/database/src/corrupt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,18 @@ impl CorruptStateSnapshotCommand {

let mut store_update = store.store_update();
// TODO(resharding) automatically detect the shard version
let shard_uids = match self.shard_layout_version {
0 => ShardLayout::v0(1, 0).get_shard_uids(),
1 => ShardLayout::get_simple_nightshade_layout().get_shard_uids(),
2 => ShardLayout::get_simple_nightshade_layout_v2().get_shard_uids(),
let shard_layout = match self.shard_layout_version {
0 => ShardLayout::v0(1, 0),
1 => ShardLayout::get_simple_nightshade_layout(),
2 => ShardLayout::get_simple_nightshade_layout_v2(),
_ => {
return Err(anyhow!(
"Unsupported shard layout version! {}",
self.shard_layout_version
))
}
};
for shard_uid in shard_uids {
for shard_uid in shard_layout.shard_uids() {
corrupt(&mut store_update, &flat_storage_manager, shard_uid)?;
}
store_update.commit().unwrap();
Expand Down
7 changes: 4 additions & 3 deletions tools/database/src/state_perf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,11 +161,12 @@ impl PerfContext {

fn generate_state_requests(store: Store, samples: usize) -> Vec<(ShardUId, ValueRef)> {
eprintln!("Generate {samples} requests to State");
let shard_uids = ShardLayout::get_simple_nightshade_layout().get_shard_uids();
let shard_uids = ShardLayout::get_simple_nightshade_layout().shard_uids().collect::<Vec<_>>();
let num_shards = shard_uids.len();
let mut ret = Vec::new();
let progress = ProgressBar::new(samples as u64);
for &shard_uid in &shard_uids {
let shard_samples = samples / shard_uids.len();
for shard_uid in shard_uids {
let shard_samples = samples / num_shards;
let mut keys_read = std::collections::HashSet::new();
for value_ref in iter_flat_state_entries(shard_uid, &store, None, None)
.flat_map(|res| res.map(|(_, value)| value.to_value_ref()))
Expand Down
2 changes: 1 addition & 1 deletion tools/fork-network/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ impl ForkNetworkCommand {
EpochManager::new_arc_handle(store.clone(), &near_config.genesis.config);
let head = store.get_ser::<Tip>(DBCol::BlockMisc, FINAL_HEAD_KEY)?.unwrap();
let shard_layout = epoch_manager.get_shard_layout(&head.epoch_id)?;
let all_shard_uids = shard_layout.get_shard_uids();
let all_shard_uids: Vec<_> = shard_layout.shard_uids().collect();
let num_shards = all_shard_uids.len();
// Flat state can be at different heights for different shards.
// That is fine, we'll simply lookup state root for each .
Expand Down
2 changes: 1 addition & 1 deletion tools/fork-network/src/storage_mutator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ impl StorageMutator {

pub(crate) fn commit(self) -> anyhow::Result<Vec<StateRoot>> {
let shard_layout = self.epoch_manager.get_shard_layout(&self.epoch_id)?;
let all_shard_uids = shard_layout.get_shard_uids();
let all_shard_uids = shard_layout.shard_uids();
let mut state_roots = vec![];
for (mutator, shard_uid) in self.mutators.into_iter().zip(all_shard_uids.into_iter()) {
let state_root = mutator.commit(&shard_uid, 0)?;
Expand Down
2 changes: 1 addition & 1 deletion tools/state-viewer/src/commands.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1082,7 +1082,7 @@ pub(crate) fn print_state_stats(home_dir: &Path, store: Store, near_config: Near
let shard_layout = epoch_manager.get_shard_layout_from_prev_block(&block_hash).unwrap();

let flat_storage_manager = runtime.get_flat_storage_manager();
for shard_uid in shard_layout.get_shard_uids() {
for shard_uid in shard_layout.shard_uids() {
print_state_stats_for_shard_uid(&store, &flat_storage_manager, block_hash, shard_uid);
}
}
Expand Down

0 comments on commit abf6e99

Please sign in to comment.