Skip to content

Commit

Permalink
Simplify message serialization to avoid writing out the protocol name
Browse files Browse the repository at this point in the history
  • Loading branch information
fjarri committed Oct 13, 2024
1 parent 1a4f52d commit 0bab9fe
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 22 deletions.
18 changes: 9 additions & 9 deletions example/src/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,20 +45,20 @@ impl ProtocolError for SimpleProtocolError {
) -> Result<(), ProtocolValidationError> {
match self {
SimpleProtocolError::Round1InvalidPosition => {
let _message = direct_message.try_deserialize::<SimpleProtocol, Round1Message>()?;
let _message = direct_message.deserialize::<SimpleProtocol, Round1Message>()?;
// Message contents would be checked here
Ok(())
}
SimpleProtocolError::Round2InvalidPosition => {
let _r1_message = direct_message.try_deserialize::<SimpleProtocol, Round1Message>()?;
let _r1_message = direct_message.deserialize::<SimpleProtocol, Round1Message>()?;
let r1_echos_serialized = combined_echos
.get(&RoundId::new(1))
.ok_or_else(|| LocalError::new("Could not find combined echos for Round 1"))?;

// Deserialize the echos
let _r1_echos = r1_echos_serialized
.iter()
.map(|echo| echo.try_deserialize::<SimpleProtocol, Round1Echo>())
.map(|echo| echo.deserialize::<SimpleProtocol, Round1Echo>())
.collect::<Result<Vec<_>, _>>()?;

// Message contents would be checked here
Expand Down Expand Up @@ -179,7 +179,7 @@ impl<Id: 'static + Debug + Clone + Ord + Send + Sync> Round<Id> for Round1<Id> {
my_position: self.context.ids_to_positions[&self.context.id],
};

Some(EchoBroadcast::new::<SimpleProtocol, _>(&message))
Some(Self::serialize_echo_broadcast(message))
}

fn make_direct_message(
Expand All @@ -193,7 +193,7 @@ impl<Id: 'static + Debug + Clone + Ord + Send + Sync> Round<Id> for Round1<Id> {
my_position: self.context.ids_to_positions[&self.context.id],
your_position: self.context.ids_to_positions[destination],
};
let dm = DirectMessage::new::<SimpleProtocol, _>(&message)?;
let dm = Self::serialize_direct_message(message)?;
let artifact = Artifact::empty();
Ok((dm, artifact))
}
Expand All @@ -207,7 +207,7 @@ impl<Id: 'static + Debug + Clone + Ord + Send + Sync> Round<Id> for Round1<Id> {
) -> Result<Payload, ReceiveError<Id, Self::Protocol>> {
debug!("{:?}: receiving message from {:?}", self.context.id, from);

let message = direct_message.try_deserialize::<SimpleProtocol, Round1Message>()?;
let message = direct_message.deserialize::<SimpleProtocol, Round1Message>()?;

debug!("{:?}: received message: {:?}", self.context.id, message);

Expand Down Expand Up @@ -283,7 +283,7 @@ impl<Id: 'static + Debug + Clone + Ord + Send + Sync> Round<Id> for Round2<Id> {
my_position: self.context.ids_to_positions[&self.context.id],
};

Some(EchoBroadcast::new::<SimpleProtocol, _>(&message))
Some(Self::serialize_echo_broadcast(message))
}

fn make_direct_message(
Expand All @@ -297,7 +297,7 @@ impl<Id: 'static + Debug + Clone + Ord + Send + Sync> Round<Id> for Round2<Id> {
my_position: self.context.ids_to_positions[&self.context.id],
your_position: self.context.ids_to_positions[destination],
};
let dm = DirectMessage::new::<SimpleProtocol, _>(&message)?;
let dm = Self::serialize_direct_message(message)?;
let artifact = Artifact::empty();
Ok((dm, artifact))
}
Expand All @@ -311,7 +311,7 @@ impl<Id: 'static + Debug + Clone + Ord + Send + Sync> Round<Id> for Round2<Id> {
) -> Result<Payload, ReceiveError<Id, Self::Protocol>> {
debug!("{:?}: receiving message from {:?}", self.context.id, from);

let message = direct_message.try_deserialize::<SimpleProtocol, Round1Message>()?;
let message = direct_message.deserialize::<SimpleProtocol, Round1Message>()?;

debug!("{:?}: received message: {:?}", self.context.id, message);

Expand Down
20 changes: 14 additions & 6 deletions manul/src/protocol/round.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,12 +291,12 @@ pub struct EchoBroadcastError(DeserializationError);
pub struct DirectMessage(#[serde(with = "serde_bytes")] Box<[u8]>);

impl DirectMessage {
pub fn new<P: Protocol, T: Serialize>(message: &T) -> Result<Self, LocalError> {
pub fn new<P: Protocol, T: Serialize>(message: T) -> Result<Self, LocalError> {
P::serialize(message).map(Self)
}

pub fn verify_is_invalid<P: Protocol, T: for<'de> Deserialize<'de>>(&self) -> Result<(), MessageValidationError> {
if self.try_deserialize::<P, T>().is_err() {
if self.deserialize::<P, T>().is_err() {
Ok(())
} else {
Err(MessageValidationError::Other(
Expand All @@ -305,7 +305,7 @@ impl DirectMessage {
}
}

pub fn try_deserialize<P: Protocol, T: for<'de> Deserialize<'de>>(&self) -> Result<T, DirectMessageError> {
pub fn deserialize<P: Protocol, T: for<'de> Deserialize<'de>>(&self) -> Result<T, DirectMessageError> {
P::deserialize(&self.0).map_err(DirectMessageError)
}
}
Expand All @@ -314,12 +314,12 @@ impl DirectMessage {
pub struct EchoBroadcast(#[serde(with = "serde_bytes")] Box<[u8]>);

impl EchoBroadcast {
pub fn new<P: Protocol, T: Serialize>(message: &T) -> Result<Self, LocalError> {
pub fn new<P: Protocol, T: Serialize>(message: T) -> Result<Self, LocalError> {
P::serialize(message).map(Self)
}

pub fn verify_is_invalid<P: Protocol, T: for<'de> Deserialize<'de>>(&self) -> Result<(), MessageValidationError> {
if self.try_deserialize::<P, T>().is_err() {
if self.deserialize::<P, T>().is_err() {
Ok(())
} else {
Err(MessageValidationError::Other(
Expand All @@ -328,7 +328,7 @@ impl EchoBroadcast {
}
}

pub fn try_deserialize<P: Protocol, T: for<'de> Deserialize<'de>>(&self) -> Result<T, EchoBroadcastError> {
pub fn deserialize<P: Protocol, T: for<'de> Deserialize<'de>>(&self) -> Result<T, EchoBroadcastError> {
P::deserialize(&self.0).map_err(EchoBroadcastError)
}
}
Expand Down Expand Up @@ -419,4 +419,12 @@ pub trait Round<Id>: 'static + Send + Sync {
) -> Result<FinalizeOutcome<Id, Self::Protocol>, FinalizeError<Id, Self::Protocol>>;

fn expecting_messages_from(&self) -> &BTreeSet<Id>;

fn serialize_echo_broadcast(message: impl Serialize) -> Result<EchoBroadcast, LocalError> {
EchoBroadcast::new::<Self::Protocol, _>(message)
}

fn serialize_direct_message(message: impl Serialize) -> Result<DirectMessage, LocalError> {
DirectMessage::new::<Self::Protocol, _>(message)
}
}
2 changes: 1 addition & 1 deletion manul/src/session/echo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ where
) -> Result<Payload, ReceiveError<Id, Self::Protocol>> {
debug!("{:?}: received an echo message from {:?}", self.verifier, from);

let message = direct_message.try_deserialize::<P, EchoRoundMessage<Id, S>>()?;
let message = direct_message.deserialize::<P, EchoRoundMessage<Id, S>>()?;

// Check that the received message contains entries from `destinations` sans `from`
// It is an unprovable fault.
Expand Down
8 changes: 3 additions & 5 deletions manul/src/session/evidence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ where

let deserialized = direct_message
.payload()
.try_deserialize::<P, EchoRoundMessage<Verifier, S>>()
.deserialize::<P, EchoRoundMessage<Verifier, S>>()
.map_err(|error| {
LocalError::new(format!("Failed to deserialize the given direct message: {:?}", error))
})?;
Expand Down Expand Up @@ -242,9 +242,7 @@ where
{
fn verify(&self, verifier: &Verifier) -> Result<(), EvidenceError> {
let verified = self.direct_message.clone().verify::<P, _>(verifier)?;
let deserialized = verified
.payload()
.try_deserialize::<P, EchoRoundMessage<Verifier, S>>()?;
let deserialized = verified.payload().deserialize::<P, EchoRoundMessage<Verifier, S>>()?;
let invalid_echo = deserialized
.echo_messages
.get(&self.invalid_echo_sender)
Expand Down Expand Up @@ -418,7 +416,7 @@ where
));
}
let echo_set =
DirectMessage::try_deserialize::<P, EchoRoundMessage<Verifier, S>>(verified_combined_echo.payload())?;
DirectMessage::deserialize::<P, EchoRoundMessage<Verifier, S>>(verified_combined_echo.payload())?;

let mut verified_echo_set = Vec::new();
for (other_verifier, echo_broadcast) in echo_set.echo_messages.iter() {
Expand Down
5 changes: 4 additions & 1 deletion manul/src/testing/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,10 @@ macro_rules! round_override {
rng: &mut impl CryptoRngCore,
payloads: ::alloc::collections::BTreeMap<Id, $crate::protocol::Payload>,
artifacts: ::alloc::collections::BTreeMap<Id, $crate::protocol::Artifact>,
) -> Result<$crate::protocol::FinalizeOutcome<Id, Self::Protocol>, $crate::protocol::FinalizeError<Id, Self::Protocol>> {
) -> Result<
$crate::protocol::FinalizeOutcome<Id, Self::Protocol>,
$crate::protocol::FinalizeError<Id, Self::Protocol>
> {
<Self as RoundOverride<Id>>::finalize(self, rng, payloads, artifacts)
}

Expand Down

0 comments on commit 0bab9fe

Please sign in to comment.