Skip to content

Commit

Permalink
Add message preprocessing (#48)
Browse files Browse the repository at this point in the history
  • Loading branch information
fjarri authored Nov 15, 2023
2 parents 1e30507 + df7615c commit edd94af
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 102 deletions.
6 changes: 6 additions & 0 deletions synedrion/src/sessions/broadcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@ impl BcConsensusAccum {
}
}

pub fn contains(&self, party_idx: PartyIdx) -> bool {
self.received_echo_from
.contains(party_idx.as_usize())
.unwrap()
}

pub fn add_echo_received(&mut self, from: PartyIdx) -> Option<()> {
self.received_echo_from.insert(from.as_usize(), ())
}
Expand Down
198 changes: 118 additions & 80 deletions synedrion/src/sessions/states.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use alloc::boxed::Box;
use alloc::collections::BTreeMap;
use alloc::format;
use alloc::vec::Vec;
use alloc::{vec, vec::Vec};
use core::fmt::Debug;

use rand_core::CryptoRngCore;
Expand Down Expand Up @@ -46,13 +46,12 @@ pub struct Session<Res, Sig, Signer, Verifier> {
enum MessageFor {
ThisRound,
NextRound,
OutOfOrder,
}

fn route_message_normal<Res: ProtocolResult, Sig>(
round: &dyn DynFinalizable<Res>,
message: &SignedMessage<Sig>,
) -> MessageFor {
) -> Result<MessageFor, RemoteErrorEnum> {
let this_round = round.round_num();
let next_round = round.next_round_num();
let requires_bc = round.requires_broadcast_consensus();
Expand All @@ -61,7 +60,7 @@ fn route_message_normal<Res: ProtocolResult, Sig>(
let message_bc = message.message_type() == MessageType::BroadcastConsensus;

if message_round == this_round && !message_bc {
return MessageFor::ThisRound;
return Ok(MessageFor::ThisRound);
}

let for_next_round =
Expand All @@ -71,29 +70,29 @@ fn route_message_normal<Res: ProtocolResult, Sig>(
(requires_bc && message_round == this_round && message_bc);

if for_next_round {
return MessageFor::NextRound;
return Ok(MessageFor::NextRound);
}

MessageFor::OutOfOrder
Err(RemoteErrorEnum::OutOfOrderMessage)
}

fn route_message_bc<Res: ProtocolResult, Sig>(
next_round: &dyn DynFinalizable<Res>,
message: &SignedMessage<Sig>,
) -> MessageFor {
) -> Result<MessageFor, RemoteErrorEnum> {
let next_round = next_round.round_num();
let message_round = message.round();
let message_bc = message.message_type() == MessageType::BroadcastConsensus;

if message_round == next_round - 1 && message_bc {
return MessageFor::ThisRound;
return Ok(MessageFor::ThisRound);
}

if message_round == next_round && !message_bc {
return MessageFor::NextRound;
return Ok(MessageFor::NextRound);
}

MessageFor::OutOfOrder
Err(RemoteErrorEnum::OutOfOrderMessage)
}

fn wrap_receive_result<Res: ProtocolResult, Verifier: Clone, T>(
Expand Down Expand Up @@ -129,7 +128,7 @@ pub enum FinalizeOutcome<Res: ProtocolResult, Sig, Signer, Verifier> {
/// The new session object.
session: Session<Res, Sig, Signer, Verifier>,
/// The messages for the new round received during the previous round.
cached_messages: Vec<(Verifier, SignedMessage<Sig>)>,
cached_messages: Vec<PreprocessedMessage<Sig>>,
},
}

Expand Down Expand Up @@ -333,12 +332,31 @@ where
}
}

/// Process a received message from another party.
pub fn verify_message(
fn route_message(
&self,
from: &Verifier,
message: &SignedMessage<Sig>,
) -> Result<MessageFor, Error<Res, Verifier>> {
let message_for = match &self.tp {
SessionType::Normal(round) => route_message_normal(round.as_ref(), message),
SessionType::Bc { next_round, .. } => route_message_bc(next_round.as_ref(), message),
};

message_for.map_err(|err| {
Error::Remote(RemoteError {
party: from.clone(),
error: err,
})
})
}

/// Perform quick checks on a received message.
pub fn preprocess_message(
&self,
accum: &mut RoundAccumulator<Sig>,
from: &Verifier,
message: SignedMessage<Sig>,
) -> Result<ProcessedMessage<Sig, Verifier>, Error<Res, Verifier>> {
) -> Result<Option<PreprocessedMessage<Sig>>, Error<Res, Verifier>> {
// This is an unprovable fault (may be a replay attack)
if message.session_id() != &self.context.session_id {
return Err(Error::Remote(RemoteError {
Expand All @@ -347,34 +365,8 @@ where
}));
}

let message_for = match &self.tp {
SessionType::Normal(round) => route_message_normal(round.as_ref(), &message),
SessionType::Bc { next_round, .. } => route_message_bc(next_round.as_ref(), &message),
};

let from_idx = self.context.verifier_to_idx[from];

match message_for {
MessageFor::ThisRound => self.verify_message_inner(from, message),
// TODO: should we cache the verified or the unverified message?
MessageFor::NextRound => Ok(ProcessedMessage {
from: from.clone(),
from_idx,
message: ProcessedMessageEnum::Cache { message },
}),
// This is an unprovable fault (may be a replay attack)
MessageFor::OutOfOrder => Err(Error::Remote(RemoteError {
party: from.clone(),
error: RemoteErrorEnum::OutOfOrderMessage,
})),
}
}
let message_for = self.route_message(from, &message)?;

fn verify_message_inner(
&self,
from: &Verifier,
message: SignedMessage<Sig>,
) -> Result<ProcessedMessage<Sig, Verifier>, Error<Res, Verifier>> {
let verified_message = message.verify(from).map_err(|err| {
Error::Remote(RemoteError {
party: from.clone(),
Expand All @@ -390,32 +382,67 @@ where
"Verifier not found: {from:?}"
))))?;

if from_idx == self.context.party_idx {
return Err(Error::Local(LocalError(
"Cannot take a message from myself".into(),
)));
}

let preprocessed = PreprocessedMessage {
from_idx,
message: verified_message,
};

Ok(match message_for {
MessageFor::ThisRound => {
if accum.is_already_processed(&preprocessed) {
return Err(Error::Remote(RemoteError {
party: from.clone(),
error: RemoteErrorEnum::DuplicateMessage,
}));
}
Some(preprocessed)
}
MessageFor::NextRound => {
if accum.is_already_cached(&preprocessed) {
return Err(Error::Remote(RemoteError {
party: from.clone(),
error: RemoteErrorEnum::DuplicateMessage,
}));
}
accum.add_cached_message(preprocessed);
None
}
})
}

/// Process a received message from another party.
pub fn process_message(
&self,
preprocessed: PreprocessedMessage<Sig>,
) -> Result<ProcessedMessage<Sig, Verifier>, Error<Res, Verifier>> {
let from_idx = preprocessed.from_idx;
let from = self.context.verifiers[preprocessed.from_idx.as_usize()].clone();
let message = preprocessed.message;
match &self.tp {
SessionType::Normal(round) => {
match verified_message.message_type() {
match message.message_type() {
MessageType::Direct => {
let result =
round.verify_direct_message(from_idx, verified_message.payload());
let payload = wrap_receive_result(from, result)?;
let result = round.verify_direct_message(from_idx, message.payload());
let payload = wrap_receive_result(&from, result)?;
Ok(ProcessedMessage {
from: from.clone(),
from_idx,
message: ProcessedMessageEnum::DmPayload {
payload,
message: verified_message,
},
message: ProcessedMessageEnum::DmPayload { payload, message },
})
}
MessageType::Broadcast => {
let result = round.verify_broadcast(from_idx, verified_message.payload());
let payload = wrap_receive_result(from, result)?;
let result = round.verify_broadcast(from_idx, message.payload());
let payload = wrap_receive_result(&from, result)?;
Ok(ProcessedMessage {
from: from.clone(),
from_idx,
message: ProcessedMessageEnum::BcPayload {
payload,
message: verified_message,
},
message: ProcessedMessageEnum::BcPayload { payload, message },
})
}
_ => {
Expand All @@ -429,7 +456,7 @@ where
}
}
SessionType::Bc { bc, .. } => {
bc.verify_broadcast(from_idx, verified_message)
bc.verify_broadcast(from_idx, message)
.map_err(|err| Error::Provable {
party: from.clone(),
error: ProvableError::Consensus(err),
Expand Down Expand Up @@ -488,12 +515,6 @@ where
}
})?;

let cached_messages = accum
.cached_messages
.into_iter()
.map(|(idx, message)| (context.verifiers[idx.as_usize()].clone(), message))
.collect();

match outcome {
type_erased::FinalizeOutcome::Success(res) => Ok(FinalizeOutcome::Success(res)),
type_erased::FinalizeOutcome::AnotherRound(next_round) => {
Expand All @@ -506,7 +527,7 @@ where
};
Ok(FinalizeOutcome::AnotherRound {
session,
cached_messages,
cached_messages: accum.cached_messages,
})
} else {
let session = Session {
Expand All @@ -515,7 +536,7 @@ where
};
Ok(FinalizeOutcome::AnotherRound {
session,
cached_messages,
cached_messages: accum.cached_messages,
})
}
}
Expand All @@ -535,20 +556,14 @@ where
.finalize()
.ok_or(Error::Local(LocalError("Cannot finalize".into())))?;

let cached_messages = accum
.cached_messages
.into_iter()
.map(|(idx, message)| (context.verifiers[idx.as_usize()].clone(), message))
.collect();

let session = Session {
tp: SessionType::Normal(round),
context,
};

Ok(FinalizeOutcome::AnotherRound {
session,
cached_messages,
cached_messages: accum.cached_messages,
})
}
}
Expand All @@ -557,7 +572,8 @@ pub struct RoundAccumulator<Sig> {
received_direct_messages: Vec<(PartyIdx, VerifiedMessage<Sig>)>,
received_broadcasts: Vec<(PartyIdx, VerifiedMessage<Sig>)>,
processed: DynRoundAccum,
cached_messages: Vec<(PartyIdx, SignedMessage<Sig>)>,
cached_messages: Vec<PreprocessedMessage<Sig>>,
cached_message_count: Vec<usize>,
bc_accum: Option<BcConsensusAccum>,
}

Expand All @@ -575,6 +591,7 @@ impl<Sig> RoundAccumulator<Sig> {
received_broadcasts: Vec::new(),
processed: DynRoundAccum::new(num_parties, party_idx, is_bc_round, is_dm_round),
cached_messages: Vec::new(),
cached_message_count: vec![0; num_parties],
bc_accum: if is_bc_consensus_round {
Some(BcConsensusAccum::new(num_parties, party_idx))
} else {
Expand Down Expand Up @@ -641,11 +658,6 @@ impl<Sig> RoundAccumulator<Sig> {
};
self.received_direct_messages.push((pm.from_idx, message));
}
ProcessedMessageEnum::Cache { message } => {
// TODO: check at this stage that there are no duplicate messages,
// without waiting for the next round
self.cached_messages.push((pm.from_idx, message));
}
ProcessedMessageEnum::Bc => match &mut self.bc_accum {
Some(accum) => {
if accum.add_echo_received(pm.from_idx).is_none() {
Expand All @@ -660,6 +672,30 @@ impl<Sig> RoundAccumulator<Sig> {
}
Ok(Ok(()))
}

fn is_already_processed(&self, preprocessed: &PreprocessedMessage<Sig>) -> bool {
match preprocessed.message.message_type() {
MessageType::Direct => self.processed.contains(preprocessed.from_idx, false),
MessageType::Broadcast => self.processed.contains(preprocessed.from_idx, true),
MessageType::BroadcastConsensus => self
.bc_accum
.as_ref()
.unwrap()
.contains(preprocessed.from_idx),
}
}

fn is_already_cached(&self, preprocessed: &PreprocessedMessage<Sig>) -> bool {
// Since we don't know yet whether the next round requires two types of messages
// (direct & broadcast) or just one, we limit the cached messages with 2 per party.
// This is enough to not get DDOS'ed by messages for the next round.
self.cached_message_count[preprocessed.from_idx.as_usize()] == 2
}

fn add_cached_message(&mut self, preprocessed: PreprocessedMessage<Sig>) {
self.cached_message_count[preprocessed.from_idx.as_usize()] += 1;
self.cached_messages.push(preprocessed);
}
}

pub struct Artefact<Verifier> {
Expand All @@ -668,6 +704,11 @@ pub struct Artefact<Verifier> {
artefact: DynDmArtefact,
}

pub struct PreprocessedMessage<Sig> {
from_idx: PartyIdx,
message: VerifiedMessage<Sig>,
}

pub struct ProcessedMessage<Sig, Verifier> {
from: Verifier,
from_idx: PartyIdx,
Expand All @@ -683,8 +724,5 @@ enum ProcessedMessageEnum<Sig> {
payload: DynDmPayload,
message: VerifiedMessage<Sig>,
},
Cache {
message: SignedMessage<Sig>,
},
Bc,
}
Loading

0 comments on commit edd94af

Please sign in to comment.