Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add message preprocessing #48

Merged
merged 1 commit into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions synedrion/src/sessions/broadcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@ impl BcConsensusAccum {
}
}

pub fn contains(&self, party_idx: PartyIdx) -> bool {
self.received_echo_from
.contains(party_idx.as_usize())
.unwrap()
}

pub fn add_echo_received(&mut self, from: PartyIdx) -> Option<()> {
self.received_echo_from.insert(from.as_usize(), ())
}
Expand Down
198 changes: 118 additions & 80 deletions synedrion/src/sessions/states.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use alloc::boxed::Box;
use alloc::collections::BTreeMap;
use alloc::format;
use alloc::vec::Vec;
use alloc::{vec, vec::Vec};
use core::fmt::Debug;

use rand_core::CryptoRngCore;
Expand Down Expand Up @@ -46,13 +46,12 @@ pub struct Session<Res, Sig, Signer, Verifier> {
enum MessageFor {
ThisRound,
NextRound,
OutOfOrder,
}

fn route_message_normal<Res: ProtocolResult, Sig>(
round: &dyn DynFinalizable<Res>,
message: &SignedMessage<Sig>,
) -> MessageFor {
) -> Result<MessageFor, RemoteErrorEnum> {
let this_round = round.round_num();
let next_round = round.next_round_num();
let requires_bc = round.requires_broadcast_consensus();
Expand All @@ -61,7 +60,7 @@ fn route_message_normal<Res: ProtocolResult, Sig>(
let message_bc = message.message_type() == MessageType::BroadcastConsensus;

if message_round == this_round && !message_bc {
return MessageFor::ThisRound;
return Ok(MessageFor::ThisRound);
}

let for_next_round =
Expand All @@ -71,29 +70,29 @@ fn route_message_normal<Res: ProtocolResult, Sig>(
(requires_bc && message_round == this_round && message_bc);

if for_next_round {
return MessageFor::NextRound;
return Ok(MessageFor::NextRound);
}

MessageFor::OutOfOrder
Err(RemoteErrorEnum::OutOfOrderMessage)
}

fn route_message_bc<Res: ProtocolResult, Sig>(
next_round: &dyn DynFinalizable<Res>,
message: &SignedMessage<Sig>,
) -> MessageFor {
) -> Result<MessageFor, RemoteErrorEnum> {
let next_round = next_round.round_num();
let message_round = message.round();
let message_bc = message.message_type() == MessageType::BroadcastConsensus;

if message_round == next_round - 1 && message_bc {
return MessageFor::ThisRound;
return Ok(MessageFor::ThisRound);
}

if message_round == next_round && !message_bc {
return MessageFor::NextRound;
return Ok(MessageFor::NextRound);
}

MessageFor::OutOfOrder
Err(RemoteErrorEnum::OutOfOrderMessage)
}

fn wrap_receive_result<Res: ProtocolResult, Verifier: Clone, T>(
Expand Down Expand Up @@ -129,7 +128,7 @@ pub enum FinalizeOutcome<Res: ProtocolResult, Sig, Signer, Verifier> {
/// The new session object.
session: Session<Res, Sig, Signer, Verifier>,
/// The messages for the new round received during the previous round.
cached_messages: Vec<(Verifier, SignedMessage<Sig>)>,
cached_messages: Vec<PreprocessedMessage<Sig>>,
},
}

Expand Down Expand Up @@ -333,12 +332,31 @@ where
}
}

/// Process a received message from another party.
pub fn verify_message(
fn route_message(
&self,
from: &Verifier,
message: &SignedMessage<Sig>,
) -> Result<MessageFor, Error<Res, Verifier>> {
let message_for = match &self.tp {
SessionType::Normal(round) => route_message_normal(round.as_ref(), message),
SessionType::Bc { next_round, .. } => route_message_bc(next_round.as_ref(), message),
};

message_for.map_err(|err| {
Error::Remote(RemoteError {
party: from.clone(),
error: err,
})
})
}

/// Perform quick checks on a received message.
pub fn preprocess_message(
&self,
accum: &mut RoundAccumulator<Sig>,
from: &Verifier,
message: SignedMessage<Sig>,
) -> Result<ProcessedMessage<Sig, Verifier>, Error<Res, Verifier>> {
) -> Result<Option<PreprocessedMessage<Sig>>, Error<Res, Verifier>> {
// This is an unprovable fault (may be a replay attack)
if message.session_id() != &self.context.session_id {
return Err(Error::Remote(RemoteError {
Expand All @@ -347,34 +365,8 @@ where
}));
}

let message_for = match &self.tp {
SessionType::Normal(round) => route_message_normal(round.as_ref(), &message),
SessionType::Bc { next_round, .. } => route_message_bc(next_round.as_ref(), &message),
};

let from_idx = self.context.verifier_to_idx[from];

match message_for {
MessageFor::ThisRound => self.verify_message_inner(from, message),
// TODO: should we cache the verified or the unverified message?
MessageFor::NextRound => Ok(ProcessedMessage {
from: from.clone(),
from_idx,
message: ProcessedMessageEnum::Cache { message },
}),
// This is an unprovable fault (may be a replay attack)
MessageFor::OutOfOrder => Err(Error::Remote(RemoteError {
party: from.clone(),
error: RemoteErrorEnum::OutOfOrderMessage,
})),
}
}
let message_for = self.route_message(from, &message)?;

fn verify_message_inner(
&self,
from: &Verifier,
message: SignedMessage<Sig>,
) -> Result<ProcessedMessage<Sig, Verifier>, Error<Res, Verifier>> {
let verified_message = message.verify(from).map_err(|err| {
Error::Remote(RemoteError {
party: from.clone(),
Expand All @@ -390,32 +382,67 @@ where
"Verifier not found: {from:?}"
))))?;

if from_idx == self.context.party_idx {
return Err(Error::Local(LocalError(
"Cannot take a message from myself".into(),
)));
}

let preprocessed = PreprocessedMessage {
from_idx,
message: verified_message,
};

Ok(match message_for {
MessageFor::ThisRound => {
if accum.is_already_processed(&preprocessed) {
return Err(Error::Remote(RemoteError {
party: from.clone(),
error: RemoteErrorEnum::DuplicateMessage,
}));
}
Some(preprocessed)
}
MessageFor::NextRound => {
if accum.is_already_cached(&preprocessed) {
return Err(Error::Remote(RemoteError {
party: from.clone(),
error: RemoteErrorEnum::DuplicateMessage,
}));
}
accum.add_cached_message(preprocessed);
None
}
})
}

/// Process a received message from another party.
pub fn process_message(
&self,
preprocessed: PreprocessedMessage<Sig>,
) -> Result<ProcessedMessage<Sig, Verifier>, Error<Res, Verifier>> {
let from_idx = preprocessed.from_idx;
let from = self.context.verifiers[preprocessed.from_idx.as_usize()].clone();
let message = preprocessed.message;
match &self.tp {
SessionType::Normal(round) => {
match verified_message.message_type() {
match message.message_type() {
MessageType::Direct => {
let result =
round.verify_direct_message(from_idx, verified_message.payload());
let payload = wrap_receive_result(from, result)?;
let result = round.verify_direct_message(from_idx, message.payload());
let payload = wrap_receive_result(&from, result)?;
Ok(ProcessedMessage {
from: from.clone(),
from_idx,
message: ProcessedMessageEnum::DmPayload {
payload,
message: verified_message,
},
message: ProcessedMessageEnum::DmPayload { payload, message },
})
}
MessageType::Broadcast => {
let result = round.verify_broadcast(from_idx, verified_message.payload());
let payload = wrap_receive_result(from, result)?;
let result = round.verify_broadcast(from_idx, message.payload());
let payload = wrap_receive_result(&from, result)?;
Ok(ProcessedMessage {
from: from.clone(),
from_idx,
message: ProcessedMessageEnum::BcPayload {
payload,
message: verified_message,
},
message: ProcessedMessageEnum::BcPayload { payload, message },
})
}
_ => {
Expand All @@ -429,7 +456,7 @@ where
}
}
SessionType::Bc { bc, .. } => {
bc.verify_broadcast(from_idx, verified_message)
bc.verify_broadcast(from_idx, message)
.map_err(|err| Error::Provable {
party: from.clone(),
error: ProvableError::Consensus(err),
Expand Down Expand Up @@ -488,12 +515,6 @@ where
}
})?;

let cached_messages = accum
.cached_messages
.into_iter()
.map(|(idx, message)| (context.verifiers[idx.as_usize()].clone(), message))
.collect();

match outcome {
type_erased::FinalizeOutcome::Success(res) => Ok(FinalizeOutcome::Success(res)),
type_erased::FinalizeOutcome::AnotherRound(next_round) => {
Expand All @@ -506,7 +527,7 @@ where
};
Ok(FinalizeOutcome::AnotherRound {
session,
cached_messages,
cached_messages: accum.cached_messages,
})
} else {
let session = Session {
Expand All @@ -515,7 +536,7 @@ where
};
Ok(FinalizeOutcome::AnotherRound {
session,
cached_messages,
cached_messages: accum.cached_messages,
})
}
}
Expand All @@ -535,20 +556,14 @@ where
.finalize()
.ok_or(Error::Local(LocalError("Cannot finalize".into())))?;

let cached_messages = accum
.cached_messages
.into_iter()
.map(|(idx, message)| (context.verifiers[idx.as_usize()].clone(), message))
.collect();

let session = Session {
tp: SessionType::Normal(round),
context,
};

Ok(FinalizeOutcome::AnotherRound {
session,
cached_messages,
cached_messages: accum.cached_messages,
})
}
}
Expand All @@ -557,7 +572,8 @@ pub struct RoundAccumulator<Sig> {
received_direct_messages: Vec<(PartyIdx, VerifiedMessage<Sig>)>,
received_broadcasts: Vec<(PartyIdx, VerifiedMessage<Sig>)>,
processed: DynRoundAccum,
cached_messages: Vec<(PartyIdx, SignedMessage<Sig>)>,
cached_messages: Vec<PreprocessedMessage<Sig>>,
cached_message_count: Vec<usize>,
bc_accum: Option<BcConsensusAccum>,
}

Expand All @@ -575,6 +591,7 @@ impl<Sig> RoundAccumulator<Sig> {
received_broadcasts: Vec::new(),
processed: DynRoundAccum::new(num_parties, party_idx, is_bc_round, is_dm_round),
cached_messages: Vec::new(),
cached_message_count: vec![0; num_parties],
bc_accum: if is_bc_consensus_round {
Some(BcConsensusAccum::new(num_parties, party_idx))
} else {
Expand Down Expand Up @@ -641,11 +658,6 @@ impl<Sig> RoundAccumulator<Sig> {
};
self.received_direct_messages.push((pm.from_idx, message));
}
ProcessedMessageEnum::Cache { message } => {
// TODO: check at this stage that there are no duplicate messages,
// without waiting for the next round
self.cached_messages.push((pm.from_idx, message));
}
ProcessedMessageEnum::Bc => match &mut self.bc_accum {
Some(accum) => {
if accum.add_echo_received(pm.from_idx).is_none() {
Expand All @@ -660,6 +672,30 @@ impl<Sig> RoundAccumulator<Sig> {
}
Ok(Ok(()))
}

fn is_already_processed(&self, preprocessed: &PreprocessedMessage<Sig>) -> bool {
match preprocessed.message.message_type() {
MessageType::Direct => self.processed.contains(preprocessed.from_idx, false),
MessageType::Broadcast => self.processed.contains(preprocessed.from_idx, true),
MessageType::BroadcastConsensus => self
.bc_accum
.as_ref()
.unwrap()
.contains(preprocessed.from_idx),
}
}

fn is_already_cached(&self, preprocessed: &PreprocessedMessage<Sig>) -> bool {
// Since we don't know yet whether the next round requires two types of messages
// (direct & broadcast) or just one, we limit the cached messages with 2 per party.
// This is enough to not get DDOS'ed by messages for the next round.
self.cached_message_count[preprocessed.from_idx.as_usize()] == 2
}

fn add_cached_message(&mut self, preprocessed: PreprocessedMessage<Sig>) {
self.cached_message_count[preprocessed.from_idx.as_usize()] += 1;
self.cached_messages.push(preprocessed);
}
}

pub struct Artefact<Verifier> {
Expand All @@ -668,6 +704,11 @@ pub struct Artefact<Verifier> {
artefact: DynDmArtefact,
}

pub struct PreprocessedMessage<Sig> {
from_idx: PartyIdx,
message: VerifiedMessage<Sig>,
}

pub struct ProcessedMessage<Sig, Verifier> {
from: Verifier,
from_idx: PartyIdx,
Expand All @@ -683,8 +724,5 @@ enum ProcessedMessageEnum<Sig> {
payload: DynDmPayload,
message: VerifiedMessage<Sig>,
},
Cache {
message: SignedMessage<Sig>,
},
Bc,
}
Loading
Loading