diff --git a/Cargo.lock b/Cargo.lock index f007ab848a..8dd2b5c68f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3709,7 +3709,7 @@ dependencies = [ [[package]] name = "mithril-common" -version = "0.4.73" +version = "0.4.74" dependencies = [ "anyhow", "async-trait", diff --git a/mithril-common/Cargo.toml b/mithril-common/Cargo.toml index 9a2acf1927..4140098b3b 100644 --- a/mithril-common/Cargo.toml +++ b/mithril-common/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "mithril-common" -version = "0.4.73" +version = "0.4.74" description = "Common types, interfaces, and utilities for Mithril nodes." authors = { workspace = true } edition = { workspace = true } diff --git a/mithril-common/src/certificate_chain/certificate_retriever.rs b/mithril-common/src/certificate_chain/certificate_retriever.rs index 77866c2bf1..53e34f1269 100644 --- a/mithril-common/src/certificate_chain/certificate_retriever.rs +++ b/mithril-common/src/certificate_chain/certificate_retriever.rs @@ -5,16 +5,13 @@ use thiserror::Error; use crate::{entities::Certificate, StdError}; -#[cfg(test)] -use mockall::automock; - /// [CertificateRetriever] related errors. #[derive(Debug, Error)] #[error("Error when retrieving certificate")] pub struct CertificateRetrieverError(#[source] pub StdError); /// CertificateRetriever is in charge of retrieving a [Certificate] given its hash -#[cfg_attr(test, automock)] +#[cfg_attr(test, mockall::automock)] #[cfg_attr(target_family = "wasm", async_trait(?Send))] #[cfg_attr(not(target_family = "wasm"), async_trait)] pub trait CertificateRetriever: Sync + Send { diff --git a/mithril-common/src/certificate_chain/certificate_verifier.rs b/mithril-common/src/certificate_chain/certificate_verifier.rs index da25e34720..f7389defeb 100644 --- a/mithril-common/src/certificate_chain/certificate_verifier.rs +++ b/mithril-common/src/certificate_chain/certificate_verifier.rs @@ -306,7 +306,7 @@ mod tests { use super::CertificateRetriever; use super::*; - use crate::certificate_chain::CertificateRetrieverError; + use crate::certificate_chain::{CertificateRetrieverError, FakeCertificaterRetriever}; use crate::crypto_helper::{tests_setup::*, ProtocolClerk}; use crate::test_utils::{MithrilFixtureBuilder, TestLogger}; @@ -527,24 +527,19 @@ mod tests { let certificates_per_epoch = 2; let (fake_certificates, genesis_verifier) = setup_certificate_chain(total_certificates, certificates_per_epoch); - let mut mock_certificate_retriever = MockCertificateRetrieverImpl::new(); + let certificate_retriever = + FakeCertificaterRetriever::from_certificates(&fake_certificates); + let verifier = + MithrilCertificateVerifier::new(TestLogger::stdout(), Arc::new(certificate_retriever)); let certificate_to_verify = fake_certificates[0].clone(); - for fake_certificate in fake_certificates.into_iter().skip(1) { - mock_certificate_retriever - .expect_get_certificate_details() - .returning(move |_| Ok(fake_certificate.clone())) - .times(1); - } - let verifier = MithrilCertificateVerifier::new( - TestLogger::stdout(), - Arc::new(mock_certificate_retriever), - ); + let verify = verifier .verify_certificate_chain( certificate_to_verify, &genesis_verifier.to_verification_key(), ) .await; + verify.expect("unexpected error"); } @@ -555,23 +550,13 @@ mod tests { let (mut fake_certificates, genesis_verifier) = setup_certificate_chain(total_certificates, certificates_per_epoch); let index_certificate_fail = (total_certificates / 2) as usize; - fake_certificates[index_certificate_fail].hash = "tampered-hash".to_string(); - let mut mock_certificate_retriever = MockCertificateRetrieverImpl::new(); + fake_certificates[index_certificate_fail].signed_message = "tampered-message".to_string(); + let certificate_retriever = + FakeCertificaterRetriever::from_certificates(&fake_certificates); + let verifier = + MithrilCertificateVerifier::new(TestLogger::stdout(), Arc::new(certificate_retriever)); let certificate_to_verify = fake_certificates[0].clone(); - for fake_certificate in fake_certificates - .into_iter() - .skip(1) - .take(index_certificate_fail) - { - mock_certificate_retriever - .expect_get_certificate_details() - .returning(move |_| Ok(fake_certificate.clone())) - .times(1); - } - let verifier = MithrilCertificateVerifier::new( - TestLogger::stdout(), - Arc::new(mock_certificate_retriever), - ); + let error = verifier .verify_certificate_chain( certificate_to_verify, @@ -584,10 +569,7 @@ mod tests { .expect("Can not downcast to `CertificateVerifierError`."); assert!( - matches!( - error, - CertificateVerifierError::CertificateChainPreviousHashUnmatch - ), + matches!(error, CertificateVerifierError::CertificateHashUnmatch), "unexpected error type: {error:?}" ); } diff --git a/mithril-common/src/certificate_chain/fake_certificate_retriever.rs b/mithril-common/src/certificate_chain/fake_certificate_retriever.rs new file mode 100644 index 0000000000..d0ca31d21a --- /dev/null +++ b/mithril-common/src/certificate_chain/fake_certificate_retriever.rs @@ -0,0 +1,77 @@ +//! A module used for a fake implementation of a certificate chain retriever +//! + +use anyhow::anyhow; +use async_trait::async_trait; +use std::collections::HashMap; +use tokio::sync::RwLock; + +use crate::entities::Certificate; + +use super::{CertificateRetriever, CertificateRetrieverError}; + +/// A fake [CertificateRetriever] that returns a [Certificate] given its hash +pub struct FakeCertificaterRetriever { + certificates_map: RwLock>, +} + +impl FakeCertificaterRetriever { + /// Create a new [FakeCertificaterRetriever] + pub fn from_certificates(certificates: &[Certificate]) -> Self { + let certificates_map = certificates + .iter() + .map(|certificate| (certificate.hash.clone(), certificate.clone())) + .collect::>(); + let certificates_map = RwLock::new(certificates_map); + + Self { certificates_map } + } +} + +#[async_trait] +impl CertificateRetriever for FakeCertificaterRetriever { + async fn get_certificate_details( + &self, + certificate_hash: &str, + ) -> Result { + let certificates_map = self.certificates_map.read().await; + certificates_map + .get(certificate_hash) + .cloned() + .ok_or_else(|| CertificateRetrieverError(anyhow!("Certificate not found"))) + } +} + +#[cfg(test)] +mod tests { + use crate::test_utils::fake_data; + + use super::*; + + #[tokio::test] + async fn fake_certificate_retriever_retrieves_existing_certificate() { + let certificate = fake_data::certificate("certificate-hash-123".to_string()); + let certificate_hash = certificate.hash.clone(); + let certificate_retriever = + FakeCertificaterRetriever::from_certificates(&[certificate.clone()]); + + let retrieved_certificate = certificate_retriever + .get_certificate_details(&certificate_hash) + .await + .expect("Should retrieve certificate"); + + assert_eq!(retrieved_certificate, certificate); + } + + #[tokio::test] + async fn test_fake_certificate_fails_retrieving_unknow_certificate() { + let certificate = fake_data::certificate("certificate-hash-123".to_string()); + let certificate_retriever = FakeCertificaterRetriever::from_certificates(&[certificate]); + + let retrieved_certificate = certificate_retriever + .get_certificate_details("certificate-hash-not-found") + .await; + + retrieved_certificate.expect_err("get_certificate_details shoudl fail"); + } +} diff --git a/mithril-common/src/certificate_chain/mod.rs b/mithril-common/src/certificate_chain/mod.rs index 07ce00ab5c..dda71e61ff 100644 --- a/mithril-common/src/certificate_chain/mod.rs +++ b/mithril-common/src/certificate_chain/mod.rs @@ -3,9 +3,16 @@ mod certificate_genesis; mod certificate_retriever; mod certificate_verifier; +cfg_test_tools! { + mod fake_certificate_retriever; +} pub use certificate_genesis::CertificateGenesisProducer; pub use certificate_retriever::{CertificateRetriever, CertificateRetrieverError}; pub use certificate_verifier::{ CertificateVerifier, CertificateVerifierError, MithrilCertificateVerifier, }; + +cfg_test_tools! { + pub use fake_certificate_retriever::FakeCertificaterRetriever; +} diff --git a/mithril-common/src/test_utils/certificate_chain_builder.rs b/mithril-common/src/test_utils/certificate_chain_builder.rs index 832b316da5..b1b39b25cc 100644 --- a/mithril-common/src/test_utils/certificate_chain_builder.rs +++ b/mithril-common/src/test_utils/certificate_chain_builder.rs @@ -419,29 +419,58 @@ impl<'a> CertificateChainBuilder<'a> { certificate } + fn update_certificate_previous_hash( + &self, + certificate: Certificate, + previous_certificate: Option<&Certificate>, + ) -> Certificate { + let mut certificate = certificate; + certificate.previous_hash = previous_certificate + .map(|c| c.hash.to_string()) + .unwrap_or_default(); + certificate.hash = certificate.compute_hash(); + + certificate + } + + fn fetch_previous_certificate_from_chain<'b>( + &self, + certificate: &Certificate, + certificates_chained: &'b [Certificate], + ) -> Option<&'b Certificate> { + let is_certificate_first_of_epoch = certificates_chained + .last() + .map(|c| c.epoch != certificate.epoch) + .unwrap_or(true); + + certificates_chained + .iter() + .rev() + .filter(|c| { + if is_certificate_first_of_epoch { + // The previous certificate of the first certificate of an epoch + // is the first certificate of the previous epoch + c.epoch == certificate.epoch.previous().unwrap() + } else { + // The previous certificate of not the first certificate of an epoch + // is the first certificate of the epoch + c.epoch == certificate.epoch + } + }) + .last() + } + // Returns the chained certificates in reverse order // The latest certificate of the chain is the first in the vector fn compute_chained_certificates(&self, certificates: Vec) -> Vec { - fn update_certificate_previous_hash( - certificate: Certificate, - previous_certificate: Option<&Certificate>, - ) -> Certificate { - let mut certificate = certificate; - certificate.previous_hash = previous_certificate - .map(|c| c.hash.to_string()) - .unwrap_or_default(); - certificate.hash = certificate.compute_hash(); - - certificate - } - let mut certificates_chained: Vec = certificates .into_iter() .fold(Vec::new(), |mut certificates_chained, certificate| { - let previous_certificate_maybe = certificates_chained.last(); - let certificate = - update_certificate_previous_hash(certificate, previous_certificate_maybe); + let previous_certificate_maybe = self + .fetch_previous_certificate_from_chain(&certificate, &certificates_chained); + let certificate = self + .update_certificate_previous_hash(certificate, previous_certificate_maybe); certificates_chained.push(certificate); certificates_chained @@ -760,32 +789,59 @@ mod test { #[test] fn builds_certificate_chain_correctly_chained() { - let certificates = vec![ - Certificate { - epoch: Epoch(1), - ..fake_data::certificate("cert-1".to_string()) - }, - Certificate { - epoch: Epoch(2), - ..fake_data::certificate("cert-2".to_string()) - }, + fn create_fake_certificate(epoch: Epoch, index_in_epoch: u64) -> Certificate { Certificate { - epoch: Epoch(3), - ..fake_data::certificate("cert-3".to_string()) - }, + epoch, + signed_message: format!("certificate-{}-{index_in_epoch}", *epoch), + ..fake_data::certificate("cert-fake".to_string()) + } + } + + let certificates = vec![ + create_fake_certificate(Epoch(1), 1), + create_fake_certificate(Epoch(2), 1), + create_fake_certificate(Epoch(2), 2), + create_fake_certificate(Epoch(3), 1), + create_fake_certificate(Epoch(4), 1), + create_fake_certificate(Epoch(4), 2), + create_fake_certificate(Epoch(4), 3), ]; - let certificates_chained = + let mut certificates_chained = CertificateChainBuilder::default().compute_chained_certificates(certificates); + certificates_chained.reverse(); - assert_eq!("", certificates_chained[2].previous_hash); + let certificate_chained_1_1 = &certificates_chained[0]; + let certificate_chained_2_1 = &certificates_chained[1]; + let certificate_chained_2_2 = &certificates_chained[2]; + let certificate_chained_3_1 = &certificates_chained[3]; + let certificate_chained_4_1 = &certificates_chained[4]; + let certificate_chained_4_2 = &certificates_chained[5]; + let certificate_chained_4_3 = &certificates_chained[6]; + assert_eq!("", certificate_chained_1_1.previous_hash); + assert_eq!( + certificate_chained_2_1.previous_hash, + certificate_chained_1_1.hash + ); + assert_eq!( + certificate_chained_2_2.previous_hash, + certificate_chained_2_1.hash + ); + assert_eq!( + certificate_chained_3_1.previous_hash, + certificate_chained_2_1.hash + ); + assert_eq!( + certificate_chained_4_1.previous_hash, + certificate_chained_3_1.hash + ); assert_eq!( - certificates_chained[2].hash, - certificates_chained[1].previous_hash + certificate_chained_4_2.previous_hash, + certificate_chained_4_1.hash ); assert_eq!( - certificates_chained[1].hash, - certificates_chained[0].previous_hash + certificate_chained_4_3.previous_hash, + certificate_chained_4_1.hash ); }