Skip to content

Commit

Permalink
fix(resharding) - error handling in shard layout part 1 (#12372)
Browse files Browse the repository at this point in the history
Added error handling to `get_shard_index`. There are a few places in the
codebase where the shard id may be invalid - such as when a new block or
a new state witness is received. To be future-proof, I'm adding this
error handling in order to avoid panics during validation.

In most production code I added `?`. In most test code I added `unwrap`
or `expect`. In a few places I needed to manually convert the error with
`.map_err(Into::<EpochError>::into)?;`. Also includes a small refactoring of
`get_prev_shard_id` and `get_prev_shard_ids`.
  • Loading branch information
wacban authored Nov 4, 2024
1 parent cf6d2ef commit ad96419
Show file tree
Hide file tree
Showing 33 changed files with 175 additions and 140 deletions.
22 changes: 11 additions & 11 deletions chain/chain/src/chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -644,7 +644,7 @@ impl Chain {
genesis_protocol_version: ProtocolVersion,
congestion_info: Option<CongestionInfo>,
) -> Result<ChunkExtra, Error> {
let shard_index = shard_layout.get_shard_index(shard_id);
let shard_index = shard_layout.get_shard_index(shard_id)?;
let state_root = *get_genesis_state_roots(self.chain_store.store())?
.ok_or_else(|| Error::Other("genesis state roots do not exist in the db".to_owned()))?
.get(shard_index)
Expand Down Expand Up @@ -827,7 +827,7 @@ impl Chain {
&prev_block,
&shard_layout,
&shards_to_state_sync,
);
)?;
Ok(Some(state_sync_info))
}
}
Expand Down Expand Up @@ -1977,7 +1977,7 @@ impl Chain {
// TODO(#8055): this zip relies on the ordering of the apply_results.
// TODO(wacban): do the above todo
for (shard_id, apply_result) in apply_results.iter() {
let shard_index = shard_layout.get_shard_index(*shard_id);
let shard_index = shard_layout.get_shard_index(*shard_id)?;
if let Err(err) = apply_result {
if err.is_bad_data() {
let chunk = block.chunks()[shard_index].clone();
Expand Down Expand Up @@ -2578,7 +2578,7 @@ impl Chain {

let mut done = true;
for (shard_id, num_new_chunks) in num_new_chunks.iter_mut() {
let shard_index = shard_layout.get_shard_index(*shard_id);
let shard_index = shard_layout.get_shard_index(*shard_id)?;
let Some(included) = header.chunk_mask().get(shard_index) else {
return Err(Error::Other(format!(
"can't get shard {} in chunk mask for block {}",
Expand Down Expand Up @@ -2631,7 +2631,7 @@ impl Chain {
let shard_layout = self.epoch_manager.get_shard_layout(&sync_block_epoch_id)?;
let prev_epoch_id = sync_prev_block.header().epoch_id();
let prev_shard_layout = self.epoch_manager.get_shard_layout(&prev_epoch_id)?;
let prev_shard_index = prev_shard_layout.get_shard_index(shard_id);
let prev_shard_index = prev_shard_layout.get_shard_index(shard_id)?;

// Chunk header here is the same chunk header as at the `current` height.
let sync_prev_hash = sync_prev_block.hash();
Expand Down Expand Up @@ -2732,7 +2732,7 @@ impl Chain {
let ReceiptProof(receipts, shard_proof) = receipt_proof;
let ShardProof { from_shard_id, to_shard_id: _, proof } = shard_proof;
let receipts_hash = CryptoHash::hash_borsh(ReceiptList(shard_id, receipts));
let from_shard_index = prev_shard_layout.get_shard_index(*from_shard_id);
let from_shard_index = prev_shard_layout.get_shard_index(*from_shard_id)?;

let root_proof = block.chunks()[from_shard_index].prev_outgoing_receipts_root();
root_proofs_cur
Expand Down Expand Up @@ -2837,7 +2837,7 @@ impl Chain {
return Err(shard_id_out_of_bounds(shard_id));
}
let prev_block = self.get_block(header.prev_hash())?;
let shard_index = shard_layout.get_shard_index(shard_id);
let shard_index = shard_layout.get_shard_index(shard_id)?;
let state_root = prev_block
.chunks()
.get(shard_index)
Expand Down Expand Up @@ -3262,7 +3262,7 @@ impl Chain {
let epoch_id = block.header().epoch_id();
let shard_layout = self.epoch_manager.get_shard_layout(&epoch_id)?;
let shard_id = chunk.shard_id();
let shard_index = shard_layout.get_shard_index(shard_id);
let shard_index = shard_layout.get_shard_index(shard_id)?;

let chunk_proof = ChunkProofs {
block_header: borsh::to_vec(&block.header()).expect("Failed to serialize"),
Expand Down Expand Up @@ -3593,7 +3593,7 @@ impl Chain {
let epoch_id = block.header().epoch_id();
let shard_layout = self.epoch_manager.get_shard_layout(&epoch_id)?;
let shard_id = chunk_header.shard_id();
let shard_index = shard_layout.get_shard_index(shard_id);
let shard_index = shard_layout.get_shard_index(shard_id)?;
let prev_merkle_proofs =
Block::compute_chunk_headers_root(prev_block.chunks().iter_deprecated()).1;
let merkle_proofs = Block::compute_chunk_headers_root(block.chunks().iter_deprecated()).1;
Expand Down Expand Up @@ -3765,7 +3765,7 @@ impl Chain {
if err.is_bad_data() {
let epoch_id = block.header().epoch_id();
let shard_layout = self.epoch_manager.get_shard_layout(&epoch_id)?;
let shard_index = shard_layout.get_shard_index(shard_id);
let shard_index = shard_layout.get_shard_index(shard_id)?;

let chunk_header = block
.chunks()
Expand Down Expand Up @@ -4333,7 +4333,7 @@ impl Chain {
let block = self.get_block(&block_hash)?;
let chunks = block.chunks();
for &shard_id in shard_ids.iter() {
let shard_index = shard_layout.get_shard_index(shard_id);
let shard_index = shard_layout.get_shard_index(shard_id)?;
let chunk_header =
&chunks.get(shard_index).ok_or_else(|| Error::InvalidShardId(shard_id))?;
if chunk_header.height_included() == block.header().height() {
Expand Down
6 changes: 3 additions & 3 deletions chain/chain/src/runtime/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ impl TestEnv {
let prev_block_hash = self.head.last_block_hash;
let epoch_id = self.epoch_manager.get_epoch_id_from_prev_block(&prev_block_hash).unwrap();
let shard_layout = self.epoch_manager.get_shard_layout(&epoch_id).unwrap();
let shard_index = shard_layout.get_shard_index(shard_id);
let shard_index = shard_layout.get_shard_index(shard_id).unwrap();
let state_root = self.state_roots[shard_index];
let gas_limit = u64::MAX;
let height = self.head.height + 1;
Expand Down Expand Up @@ -330,7 +330,7 @@ impl TestEnv {
let mut all_proposals = vec![];
let mut all_receipts = vec![];
for shard_id in shard_ids {
let shard_index = shard_layout.get_shard_index(shard_id);
let shard_index = shard_layout.get_shard_index(shard_id).unwrap();
let (state_root, proposals, receipts) = self.update_runtime(
shard_id,
new_hash,
Expand Down Expand Up @@ -402,7 +402,7 @@ impl TestEnv {
)
.unwrap();
let shard_layout = self.epoch_manager.get_shard_layout(&self.head.epoch_id).unwrap();
let shard_index = shard_layout.get_shard_index(shard_id);
let shard_index = shard_layout.get_shard_index(shard_id).unwrap();
let shard_uid = self.epoch_manager.shard_id_to_uid(shard_id, &self.head.epoch_id).unwrap();
self.runtime
.view_account(&shard_uid, self.state_roots[shard_index], account_id)
Expand Down
4 changes: 2 additions & 2 deletions chain/chain/src/stateless_validation/chunk_endorsement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ pub fn validate_chunk_endorsements_in_block(
// Validation for chunks in each shard
// The signatures from chunk validators for each shard must match the ordered_chunk_validators
let shard_id = chunk_header.shard_id();
let shard_index = shard_layout.get_shard_index(shard_id);
let shard_index = shard_layout.get_shard_index(shard_id)?;

let chunk_validator_assignments = epoch_manager.get_chunk_validator_assignments(
&epoch_id,
Expand Down Expand Up @@ -172,7 +172,7 @@ pub fn validate_chunk_endorsements_in_header(
}
let chunk_mask = header.chunk_mask();
for shard_id in shard_ids.into_iter() {
let shard_index = shard_layout.get_shard_index(shard_id);
let shard_index = shard_layout.get_shard_index(shard_id)?;
// For old chunks, we optimize the block and its header by not including the chunk endorsements and
// corresponding bitmaps. Thus, we expect that the bitmap is empty for shard with no new chunk.
if chunk_mask[shard_index] != (chunk_endorsements.len(shard_index).unwrap() > 0) {
Expand Down
2 changes: 1 addition & 1 deletion chain/chain/src/store/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ pub trait ChainStoreAccess {
let block_header = self.get_block_header(&candidate_hash)?;
let shard_layout = epoch_manager.get_shard_layout(block_header.epoch_id())?;
let mut shard_id = shard_id;
let mut shard_index = shard_layout.get_shard_index(shard_id);
let mut shard_index = shard_layout.get_shard_index(shard_id)?;
loop {
let block_header = self.get_block_header(&candidate_hash)?;
if *block_header
Expand Down
3 changes: 2 additions & 1 deletion chain/chain/src/store_validator/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -595,7 +595,8 @@ pub(crate) fn trie_changes_chunk_extra_exists(

// 5. There should be ShardChunk with ShardId `shard_id`
let shard_id = shard_uid.shard_id();
let shard_index = shard_layout.get_shard_index(shard_id);
let shard_index =
unwrap_or_err!(shard_layout.get_shard_index(shard_id), "error getting shard index");
let chunks = block.chunks();
if let Some(chunk_header) = chunks.get(shard_index) {
// if the chunk is not a new chunk, skip the check
Expand Down
20 changes: 10 additions & 10 deletions chain/chain/src/test_utils/kv_runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ impl EpochManagerAdapter for MockEpochManager {
let shard_layout = self.get_shard_layout(epoch_id)?;
let shard_id = account_id_to_shard_id(account_id, self.num_shards);
let shard_uid = ShardUId::from_shard_id_and_layout(shard_id, &shard_layout);
let shard_index = shard_layout.get_shard_index(shard_id);
let shard_index = shard_layout.get_shard_index(shard_id)?;
Ok(ShardUIdAndIndex { shard_uid, shard_index })
}

Expand All @@ -492,7 +492,7 @@ impl EpochManagerAdapter for MockEpochManager {
epoch_id: &EpochId,
) -> Result<ShardIndex, EpochError> {
let shard_layout = self.get_shard_layout(epoch_id)?;
Ok(shard_layout.get_shard_index(shard_id))
Ok(shard_layout.get_shard_index(shard_id)?)
}

fn get_block_info(&self, _hash: &CryptoHash) -> Result<Arc<BlockInfo>, EpochError> {
Expand Down Expand Up @@ -626,7 +626,7 @@ impl EpochManagerAdapter for MockEpochManager {
// This is not correct if there was a resharding event in between
// the previous and current block.
let prev_shard_id = shard_id;
let prev_shard_index = shard_layout.get_shard_index(prev_shard_id);
let prev_shard_index = shard_layout.get_shard_index(prev_shard_id)?;
prev_shard_ids.push((prev_shard_id, prev_shard_index));
}

Expand All @@ -642,7 +642,7 @@ impl EpochManagerAdapter for MockEpochManager {
// This is not correct if there was a resharding event in between
// the previous and current block.
let prev_shard_id = shard_id;
let prev_shard_index = shard_layout.get_shard_index(prev_shard_id);
let prev_shard_index = shard_layout.get_shard_index(prev_shard_id)?;
Ok((prev_shard_id, prev_shard_index))
}

Expand Down Expand Up @@ -753,7 +753,7 @@ impl EpochManagerAdapter for MockEpochManager {
) -> Result<Vec<AccountId>, EpochError> {
let valset = self.get_valset_for_epoch(epoch_id)?;
let shard_layout = self.get_shard_layout(epoch_id)?;
let shard_index = shard_layout.get_shard_index(shard_id);
let shard_index = shard_layout.get_shard_index(shard_id)?;
let chunk_producers = self.get_chunk_producers(valset, shard_index);
Ok(chunk_producers.into_iter().map(|vs| vs.take_account_id()).collect())
}
Expand Down Expand Up @@ -787,7 +787,7 @@ impl EpochManagerAdapter for MockEpochManager {
) -> Result<AccountId, EpochError> {
let valset = self.get_valset_for_epoch(epoch_id)?;
let shard_layout = self.get_shard_layout(epoch_id)?;
let shard_index = shard_layout.get_shard_index(shard_id);
let shard_index = shard_layout.get_shard_index(shard_id)?;
let chunk_producers = self.get_chunk_producers(valset, shard_index);
let index = (shard_index + height as usize + 1) % chunk_producers.len();
Ok(chunk_producers[index].account_id().clone())
Expand Down Expand Up @@ -1056,7 +1056,7 @@ impl EpochManagerAdapter for MockEpochManager {
// the calling function.
let epoch_valset = self.get_valset_for_epoch(&epoch_id).unwrap();
let shard_layout = self.get_shard_layout(&epoch_id)?;
let shard_index = shard_layout.get_shard_index(shard_id);
let shard_index = shard_layout.get_shard_index(shard_id)?;
let chunk_producers = self.get_chunk_producers(epoch_valset, shard_index);
for validator in chunk_producers {
if validator.account_id() == account_id {
Expand All @@ -1077,7 +1077,7 @@ impl EpochManagerAdapter for MockEpochManager {
// the calling function.
let epoch_valset = self.get_epoch_and_valset(*parent_hash).unwrap();
let shard_layout = self.get_shard_layout_from_prev_block(parent_hash)?;
let shard_index = shard_layout.get_shard_index(shard_id);
let shard_index = shard_layout.get_shard_index(shard_id)?;
let chunk_producers = self.get_chunk_producers(epoch_valset.1, shard_index);
for validator in chunk_producers {
if validator.account_id() == account_id {
Expand All @@ -1098,7 +1098,7 @@ impl EpochManagerAdapter for MockEpochManager {
// the calling function.
let epoch_valset = self.get_epoch_and_valset(*parent_hash).unwrap();
let shard_layout = self.get_shard_layout_from_prev_block(parent_hash)?;
let shard_index = shard_layout.get_shard_index(shard_id);
let shard_index = shard_layout.get_shard_index(shard_id)?;
let chunk_producers = self.get_chunk_producers(
(epoch_valset.1 + 1) % self.validators_by_valset.len(),
shard_index,
Expand Down Expand Up @@ -1153,7 +1153,7 @@ impl EpochManagerAdapter for MockEpochManager {
) -> Result<AccountId, EpochError> {
let valset = self.get_valset_for_epoch(epoch_id)?;
let shard_layout = self.get_shard_layout(epoch_id)?;
let shard_index = shard_layout.get_shard_index(shard_id);
let shard_index = shard_layout.get_shard_index(shard_id)?;
let chunk_producers = self.get_chunk_producers(valset, shard_index);
let index = rand::thread_rng().gen_range(0..chunk_producers.len());
Ok(chunk_producers[index].account_id().clone())
Expand Down
9 changes: 6 additions & 3 deletions chain/client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,8 @@ impl NewChunkTracker {

let mut done = true;
for (shard_id, num_new_chunks) in self.num_new_chunks.iter_mut() {
let shard_index = shard_layout.get_shard_index(*shard_id);
let shard_index =
shard_layout.get_shard_index(*shard_id).map_err(Into::<EpochError>::into)?;
let Some(included) = header.chunk_mask().get(shard_index) else {
return Err(Error::Other(format!(
"can't get shard {} in chunk mask for block {}",
Expand Down Expand Up @@ -840,7 +841,8 @@ impl Client {
// Collect new chunk headers and endorsements.
let shard_layout = self.epoch_manager.get_shard_layout(&epoch_id)?;
for (shard_id, chunk_hash) in new_chunks {
let shard_index = shard_layout.get_shard_index(shard_id);
let shard_index =
shard_layout.get_shard_index(shard_id).map_err(Into::<EpochError>::into)?;
let (mut chunk_header, chunk_endorsement) =
self.chunk_inclusion_tracker.get_chunk_header_and_endorsements(&chunk_hash)?;
*chunk_header.height_included_mut() = height;
Expand Down Expand Up @@ -1570,7 +1572,8 @@ impl Client {
.expect("Could not obtain shard layout");

let shard_id = partial_chunk.shard_id();
let shard_index = shard_layout.get_shard_index(shard_id);
let shard_index =
shard_layout.get_shard_index(shard_id).expect("Could not obtain shard index");
self.block_production_info
.record_chunk_collected(partial_chunk.height_created(), shard_index);

Expand Down
2 changes: 1 addition & 1 deletion chain/client/src/info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ impl InfoHelper {
});

for shard_id in shard_ids {
let shard_index = shard_layout.get_shard_index(shard_id);
let shard_index = shard_layout.get_shard_index(shard_id).unwrap();
let mut stake_per_cp = HashMap::<ValidatorId, Balance>::new();
stake_sum = 0;
let chunk_producers_settlement = &epoch_info.chunk_producers_settlement();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ impl Client {
self.epoch_manager.get_shard_layout(from_block.header().epoch_id())?;
for proof in receipt_proof_response.1.iter() {
let from_shard_id = proof.1.from_shard_id;
let from_shard_index = shard_layout.get_shard_index(from_shard_id);
let from_shard_index = shard_layout.get_shard_index(from_shard_id)?;
let from_chunk_hash = from_block
.chunks()
.get(from_shard_index)
Expand Down
4 changes: 2 additions & 2 deletions chain/client/src/test_utils/test_env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ impl TestEnv {
let shard_id =
client.epoch_manager.account_id_to_shard_id(&account_id, &head.epoch_id).unwrap();
let shard_layout = client.epoch_manager.get_shard_layout(&head.epoch_id).unwrap();
let shard_index = shard_layout.get_shard_index(shard_id);
let shard_index = shard_layout.get_shard_index(shard_id).unwrap();
let shard_uid = client.epoch_manager.shard_id_to_uid(shard_id, &head.epoch_id).unwrap();
let last_chunk_header = &last_block.chunks()[shard_index];

Expand Down Expand Up @@ -585,7 +585,7 @@ impl TestEnv {
client.epoch_manager.account_id_to_shard_id(&account_id, &head.epoch_id).unwrap();
let shard_uid = client.epoch_manager.shard_id_to_uid(shard_id, &head.epoch_id).unwrap();
let shard_layout = client.epoch_manager.get_shard_layout(&head.epoch_id).unwrap();
let shard_index = shard_layout.get_shard_index(shard_id);
let shard_index = shard_layout.get_shard_index(shard_id).unwrap();
let last_chunk_header = &last_block.chunks()[shard_index];
let response = client
.runtime_adapter
Expand Down
2 changes: 1 addition & 1 deletion chain/client/src/test_utils/test_loop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ where
client.epoch_manager.account_id_to_shard_id(&account_id, &head.epoch_id).unwrap();
let shard_uid = client.epoch_manager.shard_id_to_uid(shard_id, &head.epoch_id).unwrap();
let shard_layout = client.epoch_manager.get_shard_layout(&head.epoch_id).unwrap();
let shard_index = shard_layout.get_shard_index(shard_id);
let shard_index = shard_layout.get_shard_index(shard_id).unwrap();
let last_chunk_header = &last_block.chunks()[shard_index];

client
Expand Down
7 changes: 5 additions & 2 deletions chain/client/src/view_client_actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -760,7 +760,7 @@ fn get_chunk_from_block(
) -> Result<ShardChunk, near_chain::Error> {
let epoch_id = block.header().epoch_id();
let shard_layout = chain.epoch_manager.get_shard_layout(epoch_id)?;
let shard_index = shard_layout.get_shard_index(shard_id);
let shard_index = shard_layout.get_shard_index(shard_id)?;
let chunk_header = block
.chunks()
.get(shard_index)
Expand Down Expand Up @@ -1092,7 +1092,10 @@ impl Handler<GetExecutionOutcome> for ViewClientActorInner {
.epoch_manager
.account_id_to_shard_id(&account_id, &epoch_id)
.into_chain_error()?;
let target_shard_index = shard_layout.get_shard_index(target_shard_id);
let target_shard_index = shard_layout
.get_shard_index(target_shard_id)
.map_err(Into::into)
.into_chain_error()?;
let res = self.chain.get_next_block_hash_with_new_chunk(
&outcome_proof.block_hash,
target_shard_id,
Expand Down
Loading

0 comments on commit ad96419

Please sign in to comment.