From df7615c6d2441128ea7f7796c44ed2810a7de497 Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Wed, 15 Nov 2023 15:10:32 -0800 Subject: [PATCH] Add message preprocessing to eliminate some duplicate checks --- synedrion/src/sessions/broadcast.rs | 6 + synedrion/src/sessions/states.rs | 198 +++++++++++++++----------- synedrion/src/sessions/type_erased.rs | 18 +++ synedrion/src/tools/collections.rs | 18 ++- synedrion/tests/sessions.rs | 35 ++--- 5 files changed, 173 insertions(+), 102 deletions(-) diff --git a/synedrion/src/sessions/broadcast.rs b/synedrion/src/sessions/broadcast.rs index b1bbae52..af94e6d4 100644 --- a/synedrion/src/sessions/broadcast.rs +++ b/synedrion/src/sessions/broadcast.rs @@ -98,6 +98,12 @@ impl BcConsensusAccum { } } + pub fn contains(&self, party_idx: PartyIdx) -> bool { + self.received_echo_from + .contains(party_idx.as_usize()) + .unwrap() + } + pub fn add_echo_received(&mut self, from: PartyIdx) -> Option<()> { self.received_echo_from.insert(from.as_usize(), ()) } diff --git a/synedrion/src/sessions/states.rs b/synedrion/src/sessions/states.rs index 8b9b7ca0..69c46adb 100644 --- a/synedrion/src/sessions/states.rs +++ b/synedrion/src/sessions/states.rs @@ -1,7 +1,7 @@ use alloc::boxed::Box; use alloc::collections::BTreeMap; use alloc::format; -use alloc::vec::Vec; +use alloc::{vec, vec::Vec}; use core::fmt::Debug; use rand_core::CryptoRngCore; @@ -46,13 +46,12 @@ pub struct Session { enum MessageFor { ThisRound, NextRound, - OutOfOrder, } fn route_message_normal( round: &dyn DynFinalizable, message: &SignedMessage, -) -> MessageFor { +) -> Result { let this_round = round.round_num(); let next_round = round.next_round_num(); let requires_bc = round.requires_broadcast_consensus(); @@ -61,7 +60,7 @@ fn route_message_normal( let message_bc = message.message_type() == MessageType::BroadcastConsensus; if message_round == this_round && !message_bc { - return MessageFor::ThisRound; + return Ok(MessageFor::ThisRound); } let for_next_round = @@ -71,29 +70,29 @@ fn route_message_normal( (requires_bc && message_round == this_round && message_bc); if for_next_round { - return MessageFor::NextRound; + return Ok(MessageFor::NextRound); } - MessageFor::OutOfOrder + Err(RemoteErrorEnum::OutOfOrderMessage) } fn route_message_bc( next_round: &dyn DynFinalizable, message: &SignedMessage, -) -> MessageFor { +) -> Result { let next_round = next_round.round_num(); let message_round = message.round(); let message_bc = message.message_type() == MessageType::BroadcastConsensus; if message_round == next_round - 1 && message_bc { - return MessageFor::ThisRound; + return Ok(MessageFor::ThisRound); } if message_round == next_round && !message_bc { - return MessageFor::NextRound; + return Ok(MessageFor::NextRound); } - MessageFor::OutOfOrder + Err(RemoteErrorEnum::OutOfOrderMessage) } fn wrap_receive_result( @@ -129,7 +128,7 @@ pub enum FinalizeOutcome { /// The new session object. session: Session, /// The messages for the new round received during the previous round. - cached_messages: Vec<(Verifier, SignedMessage)>, + cached_messages: Vec>, }, } @@ -333,12 +332,31 @@ where } } - /// Process a received message from another party. - pub fn verify_message( + fn route_message( &self, from: &Verifier, + message: &SignedMessage, + ) -> Result> { + let message_for = match &self.tp { + SessionType::Normal(round) => route_message_normal(round.as_ref(), message), + SessionType::Bc { next_round, .. } => route_message_bc(next_round.as_ref(), message), + }; + + message_for.map_err(|err| { + Error::Remote(RemoteError { + party: from.clone(), + error: err, + }) + }) + } + + /// Perform quick checks on a received message. + pub fn preprocess_message( + &self, + accum: &mut RoundAccumulator, + from: &Verifier, message: SignedMessage, - ) -> Result, Error> { + ) -> Result>, Error> { // This is an unprovable fault (may be a replay attack) if message.session_id() != &self.context.session_id { return Err(Error::Remote(RemoteError { @@ -347,34 +365,8 @@ where })); } - let message_for = match &self.tp { - SessionType::Normal(round) => route_message_normal(round.as_ref(), &message), - SessionType::Bc { next_round, .. } => route_message_bc(next_round.as_ref(), &message), - }; - - let from_idx = self.context.verifier_to_idx[from]; - - match message_for { - MessageFor::ThisRound => self.verify_message_inner(from, message), - // TODO: should we cache the verified or the unverified message? - MessageFor::NextRound => Ok(ProcessedMessage { - from: from.clone(), - from_idx, - message: ProcessedMessageEnum::Cache { message }, - }), - // This is an unprovable fault (may be a replay attack) - MessageFor::OutOfOrder => Err(Error::Remote(RemoteError { - party: from.clone(), - error: RemoteErrorEnum::OutOfOrderMessage, - })), - } - } + let message_for = self.route_message(from, &message)?; - fn verify_message_inner( - &self, - from: &Verifier, - message: SignedMessage, - ) -> Result, Error> { let verified_message = message.verify(from).map_err(|err| { Error::Remote(RemoteError { party: from.clone(), @@ -390,32 +382,67 @@ where "Verifier not found: {from:?}" ))))?; + if from_idx == self.context.party_idx { + return Err(Error::Local(LocalError( + "Cannot take a message from myself".into(), + ))); + } + + let preprocessed = PreprocessedMessage { + from_idx, + message: verified_message, + }; + + Ok(match message_for { + MessageFor::ThisRound => { + if accum.is_already_processed(&preprocessed) { + return Err(Error::Remote(RemoteError { + party: from.clone(), + error: RemoteErrorEnum::DuplicateMessage, + })); + } + Some(preprocessed) + } + MessageFor::NextRound => { + if accum.is_already_cached(&preprocessed) { + return Err(Error::Remote(RemoteError { + party: from.clone(), + error: RemoteErrorEnum::DuplicateMessage, + })); + } + accum.add_cached_message(preprocessed); + None + } + }) + } + + /// Process a received message from another party. + pub fn process_message( + &self, + preprocessed: PreprocessedMessage, + ) -> Result, Error> { + let from_idx = preprocessed.from_idx; + let from = self.context.verifiers[preprocessed.from_idx.as_usize()].clone(); + let message = preprocessed.message; match &self.tp { SessionType::Normal(round) => { - match verified_message.message_type() { + match message.message_type() { MessageType::Direct => { - let result = - round.verify_direct_message(from_idx, verified_message.payload()); - let payload = wrap_receive_result(from, result)?; + let result = round.verify_direct_message(from_idx, message.payload()); + let payload = wrap_receive_result(&from, result)?; Ok(ProcessedMessage { from: from.clone(), from_idx, - message: ProcessedMessageEnum::DmPayload { - payload, - message: verified_message, - }, + message: ProcessedMessageEnum::DmPayload { payload, message }, }) } MessageType::Broadcast => { - let result = round.verify_broadcast(from_idx, verified_message.payload()); - let payload = wrap_receive_result(from, result)?; + let result = round.verify_broadcast(from_idx, message.payload()); + let payload = wrap_receive_result(&from, result)?; Ok(ProcessedMessage { from: from.clone(), from_idx, - message: ProcessedMessageEnum::BcPayload { - payload, - message: verified_message, - }, + message: ProcessedMessageEnum::BcPayload { payload, message }, }) } _ => { @@ -429,7 +456,7 @@ where } } SessionType::Bc { bc, .. } => { - bc.verify_broadcast(from_idx, verified_message) + bc.verify_broadcast(from_idx, message) .map_err(|err| Error::Provable { party: from.clone(), error: ProvableError::Consensus(err), @@ -488,12 +515,6 @@ where } })?; - let cached_messages = accum - .cached_messages - .into_iter() - .map(|(idx, message)| (context.verifiers[idx.as_usize()].clone(), message)) - .collect(); - match outcome { type_erased::FinalizeOutcome::Success(res) => Ok(FinalizeOutcome::Success(res)), type_erased::FinalizeOutcome::AnotherRound(next_round) => { @@ -506,7 +527,7 @@ where }; Ok(FinalizeOutcome::AnotherRound { session, - cached_messages, + cached_messages: accum.cached_messages, }) } else { let session = Session { @@ -515,7 +536,7 @@ where }; Ok(FinalizeOutcome::AnotherRound { session, - cached_messages, + cached_messages: accum.cached_messages, }) } } @@ -535,12 +556,6 @@ where .finalize() .ok_or(Error::Local(LocalError("Cannot finalize".into())))?; - let cached_messages = accum - .cached_messages - .into_iter() - .map(|(idx, message)| (context.verifiers[idx.as_usize()].clone(), message)) - .collect(); - let session = Session { tp: SessionType::Normal(round), context, @@ -548,7 +563,7 @@ where Ok(FinalizeOutcome::AnotherRound { session, - cached_messages, + cached_messages: accum.cached_messages, }) } } @@ -557,7 +572,8 @@ pub struct RoundAccumulator { received_direct_messages: Vec<(PartyIdx, VerifiedMessage)>, received_broadcasts: Vec<(PartyIdx, VerifiedMessage)>, processed: DynRoundAccum, - cached_messages: Vec<(PartyIdx, SignedMessage)>, + cached_messages: Vec>, + cached_message_count: Vec, bc_accum: Option, } @@ -575,6 +591,7 @@ impl RoundAccumulator { received_broadcasts: Vec::new(), processed: DynRoundAccum::new(num_parties, party_idx, is_bc_round, is_dm_round), cached_messages: Vec::new(), + cached_message_count: vec![0; num_parties], bc_accum: if is_bc_consensus_round { Some(BcConsensusAccum::new(num_parties, party_idx)) } else { @@ -641,11 +658,6 @@ impl RoundAccumulator { }; self.received_direct_messages.push((pm.from_idx, message)); } - ProcessedMessageEnum::Cache { message } => { - // TODO: check at this stage that there are no duplicate messages, - // without waiting for the next round - self.cached_messages.push((pm.from_idx, message)); - } ProcessedMessageEnum::Bc => match &mut self.bc_accum { Some(accum) => { if accum.add_echo_received(pm.from_idx).is_none() { @@ -660,6 +672,30 @@ impl RoundAccumulator { } Ok(Ok(())) } + + fn is_already_processed(&self, preprocessed: &PreprocessedMessage) -> bool { + match preprocessed.message.message_type() { + MessageType::Direct => self.processed.contains(preprocessed.from_idx, false), + MessageType::Broadcast => self.processed.contains(preprocessed.from_idx, true), + MessageType::BroadcastConsensus => self + .bc_accum + .as_ref() + .unwrap() + .contains(preprocessed.from_idx), + } + } + + fn is_already_cached(&self, preprocessed: &PreprocessedMessage) -> bool { + // Since we don't know yet whether the next round requires two types of messages + // (direct & broadcast) or just one, we limit the cached messages with 2 per party. + // This is enough to not get DDOS'ed by messages for the next round. + self.cached_message_count[preprocessed.from_idx.as_usize()] == 2 + } + + fn add_cached_message(&mut self, preprocessed: PreprocessedMessage) { + self.cached_message_count[preprocessed.from_idx.as_usize()] += 1; + self.cached_messages.push(preprocessed); + } } pub struct Artefact { @@ -668,6 +704,11 @@ pub struct Artefact { artefact: DynDmArtefact, } +pub struct PreprocessedMessage { + from_idx: PartyIdx, + message: VerifiedMessage, +} + pub struct ProcessedMessage { from: Verifier, from_idx: PartyIdx, @@ -683,8 +724,5 @@ enum ProcessedMessageEnum { payload: DynDmPayload, message: VerifiedMessage, }, - Cache { - message: SignedMessage, - }, Bc, } diff --git a/synedrion/src/sessions/type_erased.rs b/synedrion/src/sessions/type_erased.rs index 9a255f02..55de8092 100644 --- a/synedrion/src/sessions/type_erased.rs +++ b/synedrion/src/sessions/type_erased.rs @@ -236,6 +236,24 @@ impl DynRoundAccum { } } + pub fn contains(&self, from: PartyIdx, broadcast: bool) -> bool { + if broadcast { + return self + .bc_payloads + .as_ref() + .unwrap() + .contains(from.as_usize()) + .unwrap(); + } else { + return self + .dm_payloads + .as_ref() + .unwrap() + .contains(from.as_usize()) + .unwrap(); + } + } + pub fn add_bc_payload( &mut self, from: PartyIdx, diff --git a/synedrion/src/tools/collections.rs b/synedrion/src/tools/collections.rs index a92d64d1..5588d3bb 100644 --- a/synedrion/src/tools/collections.rs +++ b/synedrion/src/tools/collections.rs @@ -24,16 +24,24 @@ impl HoleVecAccum { Self { hole_at, elems } } - pub fn get_mut(&mut self, index: usize) -> Option<&mut Option> { - if index == self.hole_at { + fn shifted_index(&self, index: usize) -> Option { + if index == self.hole_at || index > self.elems.len() { return None; } - let index = if index > self.hole_at { + Some(if index > self.hole_at { index - 1 } else { index - }; - self.elems.get_mut(index) + }) + } + + pub fn contains(&self, index: usize) -> Option { + Some(self.elems[self.shifted_index(index)?].is_some()) + } + + pub fn get_mut(&mut self, index: usize) -> Option<&mut Option> { + let idx = self.shifted_index(index)?; + self.elems.get_mut(idx) } pub fn insert(&mut self, index: usize, value: T) -> Option<()> { diff --git a/synedrion/tests/sessions.rs b/synedrion/tests/sessions.rs index 1ea82083..29b2f08c 100644 --- a/synedrion/tests/sessions.rs +++ b/synedrion/tests/sessions.rs @@ -29,7 +29,7 @@ async fn run_session( let mut rx = rx; let mut session = session; - let mut cached_messages = Vec::<(VerifyingKey, SignedMessage)>::new(); + let mut cached_messages = Vec::new(); let key = session.verifier(); let key_str = key_to_str(&key); @@ -44,9 +44,9 @@ async fn run_session( // and we don't want to bother with synchronization. let mut accum = session.make_accumulator(); - // Note: generating/sending messages, verifying cached messages, - // and verifying newly received messages can be done in parallel, - // with the results being assembled into `accum` sequentially in the host task. + // Note: generating/sending messages and verifying newly received messages + // can be done in parallel, with the results being assembled into `accum` + // sequentially in the host task. let destinations = session.broadcast_destinations(); if let Some(destinations) = destinations { @@ -82,13 +82,10 @@ async fn run_session( } } - for (from, message) in cached_messages { + for preprocessed in cached_messages { // In production usage, this will happen in a spawned task. - println!( - "{key_str}: applying a cached message from {}", - key_to_str(&from) - ); - let result = session.verify_message(&from, message).unwrap(); + println!("{key_str}: applying a cached message"); + let result = session.process_message(preprocessed).unwrap(); // This will happen in a host task. accum.add_processed_message(result).unwrap().unwrap(); @@ -98,15 +95,19 @@ async fn run_session( println!("{key_str}: waiting for a message"); let (from, message) = rx.recv().await.unwrap(); - // TODO: check here that the message from this origin hasn't been already processed - // if accum.already_processed(message) { ... } + // Perform quick checks before proceeding with the verification. + let preprocessed = session + .preprocess_message(&mut accum, &from, message) + .unwrap(); - // In production usage, this will happen in a spawned task. - println!("{key_str}: applying a message from {}", key_to_str(&from)); - let result = session.verify_message(&from, message).unwrap(); + if let Some(preprocessed) = preprocessed { + // In production usage, this will happen in a spawned task. + println!("{key_str}: applying a message from {}", key_to_str(&from)); + let result = session.process_message(preprocessed).unwrap(); - // This will happen in a host task. - accum.add_processed_message(result).unwrap().unwrap(); + // This will happen in a host task. + accum.add_processed_message(result).unwrap().unwrap(); + } } println!("{key_str}: finalizing the round");