diff --git a/consensus/Cargo.toml b/consensus/Cargo.toml index d0668bf..34a3085 100644 --- a/consensus/Cargo.toml +++ b/consensus/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "aleph-bft" -version = "0.33.1" +version = "0.33.2" edition = "2021" authors = ["Cardinal Cryptography"] categories = ["algorithms", "data-structures", "cryptography", "database"] diff --git a/consensus/src/testing/consensus.rs b/consensus/src/testing/consensus.rs index 6867188..ee68c73 100644 --- a/consensus/src/testing/consensus.rs +++ b/consensus/src/testing/consensus.rs @@ -3,7 +3,7 @@ use crate::{ runway::{NotificationIn, NotificationOut}, testing::{complete_oneshot, gen_config, gen_delay_config, init_log}, units::{ControlHash, PreUnit, Unit, UnitCoord}, - Hasher, NodeIndex, SpawnHandle, Terminator, + Hasher, NodeCount, NodeIndex, NodeMap, SpawnHandle, Terminator, }; use aleph_bft_mock::{Hasher64, Spawner}; use codec::Encode; @@ -221,8 +221,20 @@ async fn catches_wrong_control_hash() { Terminator::create_root(exit_rx, "AlephBFT-consensus"), ), ); - let control_hash = ControlHash::new(&(vec![None; n_nodes]).into()); - let bad_pu = PreUnit::::new(1.into(), 0, control_hash); + let empty_control_hash = ControlHash::new(&(vec![None; n_nodes]).into()); + let other_initial_units: Vec<_> = (1..n_nodes) + .map(NodeIndex) + .map(|creator| PreUnit::::new(creator, 0, empty_control_hash.clone())) + .map(|pu| Unit::new(pu, rand::random())) + .collect(); + let _ = tx_in + .send(NotificationIn::NewUnits(other_initial_units.clone())) + .await; + let mut parent_hashes = NodeMap::with_size(NodeCount(n_nodes)); + for (id, unit) in other_initial_units.into_iter().enumerate() { + parent_hashes.insert(NodeIndex(id + 1), unit.hash()); + } + let bad_pu = PreUnit::::new(1.into(), 1, ControlHash::new(&parent_hashes)); let bad_control_hash: ::Hash = [0, 1, 0, 1, 0, 1, 0, 1]; assert!( bad_control_hash != bad_pu.control_hash().combined_hash, @@ -231,14 +243,14 @@ async fn catches_wrong_control_hash() { let mut control_hash = bad_pu.control_hash().clone(); control_hash.combined_hash = bad_control_hash; let bad_pu = PreUnit::new(bad_pu.creator(), bad_pu.round(), control_hash); - let bad_hash: ::Hash = [0, 1, 0, 1, 0, 1, 0, 1]; - let bad_unit = Unit::new(bad_pu, bad_hash); + let some_hash: ::Hash = [0, 1, 0, 1, 0, 1, 0, 1]; + let bad_unit = Unit::new(bad_pu, some_hash); let _ = tx_in.send(NotificationIn::NewUnits(vec![bad_unit])).await; loop { let notification = rx_out.next().await.unwrap(); trace!(target: "consensus-test", "notification {:?}", notification); if let NotificationOut::WrongControlHash(h) = notification { - assert_eq!(h, bad_hash, "Expected notification for our bad unit."); + assert_eq!(h, some_hash, "Expected notification for our bad unit."); break; } } diff --git a/consensus/src/units/validator.rs b/consensus/src/units/validator.rs index 56d8cfe..8d4d7f4 100644 --- a/consensus/src/units/validator.rs +++ b/consensus/src/units/validator.rs @@ -1,6 +1,6 @@ use crate::{ - units::{FullUnit, PreUnit, SignedUnit, UncheckedSignedUnit}, - Data, Hasher, Keychain, NodeCount, Round, SessionId, Signature, SignatureError, + units::{ControlHash, FullUnit, PreUnit, SignedUnit, UncheckedSignedUnit}, + Data, Hasher, Keychain, NodeCount, NodeMap, Round, SessionId, Signature, SignatureError, }; use std::{ fmt::{Display, Formatter, Result as FmtResult}, @@ -15,6 +15,7 @@ pub enum ValidationError { RoundTooHigh(FullUnit), WrongNumberOfMembers(PreUnit), RoundZeroWithParents(PreUnit), + RoundZeroBadControlHash(PreUnit), NotEnoughParents(PreUnit), NotDescendantOfPreviousUnit(PreUnit), } @@ -33,6 +34,9 @@ impl Display for ValidationError { pu ), RoundZeroWithParents(pu) => write!(f, "zero round unit with parents: {:?}", pu), + RoundZeroBadControlHash(pu) => { + write!(f, "zero round unit with wrong control hash: {:?}", pu) + } NotEnoughParents(pu) => write!( f, "nonzero round unit with only {:?} parents: {:?}", @@ -98,26 +102,38 @@ impl Validator { &self, su: SignedUnit, ) -> Result { - // NOTE: at this point we cannot validate correctness of the control hash, in principle it could be - // just a random hash, but we still would not be able to deduce that by looking at the unit only. let pre_unit = su.as_signable().as_pre_unit(); - if pre_unit.n_members() != self.keychain.node_count() { + let n_members = pre_unit.n_members(); + if n_members != self.keychain.node_count() { return Err(ValidationError::WrongNumberOfMembers(pre_unit.clone())); } let round = pre_unit.round(); let n_parents = pre_unit.n_parents(); - if round == 0 && n_parents > NodeCount(0) { - return Err(ValidationError::RoundZeroWithParents(pre_unit.clone())); - } - let threshold = self.threshold; - if round > 0 && n_parents < threshold { - return Err(ValidationError::NotEnoughParents(pre_unit.clone())); - } - let control_hash = &pre_unit.control_hash(); - if round > 0 && !control_hash.parents_mask[pre_unit.creator()] { - return Err(ValidationError::NotDescendantOfPreviousUnit( - pre_unit.clone(), - )); + match round { + 0 => { + if n_parents > NodeCount(0) { + return Err(ValidationError::RoundZeroWithParents(pre_unit.clone())); + } + if pre_unit.control_hash().combined_hash + != ControlHash::::combine_hashes(&NodeMap::with_size(n_members)) + { + return Err(ValidationError::RoundZeroBadControlHash(pre_unit.clone())); + } + } + // NOTE: at this point we cannot validate correctness of the control hash, in principle it could be + // just a random hash, but we still would not be able to deduce that by looking at the unit only. + _ => { + let threshold = self.threshold; + if n_parents < threshold { + return Err(ValidationError::NotEnoughParents(pre_unit.clone())); + } + let control_hash = &pre_unit.control_hash(); + if !control_hash.parents_mask[pre_unit.creator()] { + return Err(ValidationError::NotDescendantOfPreviousUnit( + pre_unit.clone(), + )); + } + } } Ok(su) } @@ -128,7 +144,9 @@ mod tests { use super::{ValidationError::*, Validator as GenericValidator}; use crate::{ creation::Creator as GenericCreator, - units::{create_units, creator_set, preunit_to_unchecked_signed_unit, preunit_to_unit}, + units::{ + create_units, creator_set, preunit_to_unchecked_signed_unit, preunit_to_unit, PreUnit, + }, NodeCount, NodeIndex, }; use aleph_bft_mock::{Hasher64, Keychain}; @@ -157,6 +175,33 @@ mod tests { assert_eq!(unchecked_unit, checked_unit.into()); } + #[test] + fn detects_wrong_initial_control_hash() { + let n_members = NodeCount(7); + let threshold = NodeCount(5); + let creator_id = NodeIndex(0); + let session_id = 0; + let round = 0; + let max_round = 2; + let creator = Creator::new(creator_id, n_members); + let keychain = Keychain::new(n_members, creator_id); + let validator = Validator::new(session_id, keychain, max_round, threshold); + let (preunit, _) = creator + .create_unit(round) + .expect("Creation should succeed."); + let mut control_hash = preunit.control_hash().clone(); + control_hash.combined_hash = [0, 1, 0, 1, 0, 1, 0, 1]; + let preunit = PreUnit::new(preunit.creator(), preunit.round(), control_hash); + let unchecked_unit = + preunit_to_unchecked_signed_unit(preunit.clone(), session_id, &keychain); + let other_preunit = match validator.validate_unit(unchecked_unit.clone()) { + Ok(_) => panic!("Validated bad unit."), + Err(RoundZeroBadControlHash(unit)) => unit, + Err(e) => panic!("Unexpected error from validator: {:?}", e), + }; + assert_eq!(other_preunit, preunit); + } + #[test] fn detects_wrong_session_id() { let n_members = NodeCount(7);